Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c9cd53aa
Unverified
Commit
c9cd53aa
authored
Oct 13, 2021
by
chenbohua3
Committed by
GitHub
Oct 13, 2021
Browse files
support dtype&scheme customization for QAT quantizer (#4137)
parent
b0f34da1
Changes
10
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
775 additions
and
190 deletions
+775
-190
docs/en_US/Compression/CustomizeCompressor.rst
docs/en_US/Compression/CustomizeCompressor.rst
+2
-2
examples/model_compress/quantization/QAT_torch_quantizer.py
examples/model_compress/quantization/QAT_torch_quantizer.py
+29
-16
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+226
-122
nni/common/version.py
nni/common/version.py
+7
-0
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+77
-31
nni/compression/pytorch/quantization/literal.py
nni/compression/pytorch/quantization/literal.py
+65
-0
nni/compression/pytorch/quantization/observers.py
nni/compression/pytorch/quantization/observers.py
+15
-0
nni/compression/pytorch/quantization/settings.py
nni/compression/pytorch/quantization/settings.py
+118
-0
nni/compression/pytorch/quantization/utils.py
nni/compression/pytorch/quantization/utils.py
+83
-0
test/ut/sdk/test_compressor_torch.py
test/ut/sdk/test_compressor_torch.py
+153
-19
No files found.
docs/en_US/Compression/CustomizeCompressor.rst
View file @
c9cd53aa
...
@@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
...
@@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
grad_output : Tensor
grad_output : Tensor
gradient of the output of quantization operation
gradient of the output of quantization operation
quant_type : QuantType
quant_type : QuantType
the type of quantization, it can be `QuantType.
QUANT_
INPUT`, `QuantType.
QUANT_
WEIGHT`, `QuantType.
QUANT_
OUTPUT`,
the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, `QuantType.OUTPUT`,
you can define different behavior for different types.
you can define different behavior for different types.
Returns
Returns
-------
-------
...
@@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
...
@@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
"""
"""
# for quant_output function, set grad to zero if the absolute value of tensor is larger than 1
# for quant_output function, set grad to zero if the absolute value of tensor is larger than 1
if quant_type == QuantType.
QUANT_
OUTPUT:
if quant_type == QuantType.OUTPUT:
grad_output[torch.abs(tensor) > 1] = 0
grad_output[torch.abs(tensor) > 1] = 0
return grad_output
return grad_output
...
...
examples/model_compress/quantization/QAT_torch_quantizer.py
View file @
c9cd53aa
...
@@ -2,11 +2,13 @@ import torch
...
@@ -2,11 +2,13 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torchvision
import
datasets
,
transforms
from
torchvision
import
datasets
,
transforms
from
nni.algorithms.compression.pytorch.quantization
import
QAT_Quantizer
from
nni.algorithms.compression.pytorch.quantization
import
QAT_Quantizer
from
nni.compression.pytorch.quantization.settings
import
set_quant_scheme_dtype
import
sys
import
sys
sys
.
path
.
append
(
'../models'
)
sys
.
path
.
append
(
'../models'
)
from
mnist.naive
import
NaiveModel
from
mnist.naive
import
NaiveModel
def
train
(
model
,
device
,
train_loader
,
optimizer
):
def
train
(
model
,
device
,
train_loader
,
optimizer
):
model
.
train
()
model
.
train
()
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
train_loader
):
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
train_loader
):
...
@@ -68,13 +70,23 @@ def main():
...
@@ -68,13 +70,23 @@ def main():
},
{
},
{
'quant_types'
:
[
'output'
,
'weight'
,
'input'
],
'quant_types'
:
[
'output'
,
'weight'
,
'input'
],
'quant_bits'
:
{
'output'
:
8
,
'weight'
:
8
,
'input'
:
8
},
'quant_bits'
:
{
'output'
:
8
,
'weight'
:
8
,
'input'
:
8
},
'op_names'
:
[
'fc1'
],
'op_names'
:
[
'fc1'
,
'fc2'
],
},
{
'quant_types'
:
[
'output'
,
'weight'
,
'input'
],
'quant_bits'
:
{
'output'
:
8
,
'weight'
:
8
,
'input'
:
8
},
'op_names'
:
[
'fc2'
],
}]
}]
# you can also set the quantization dtype and scheme layer-wise through configure_list like:
# configure_list = [{
# 'quant_types': ['weight', 'input'],
# 'quant_bits': {'weight': 8, 'input': 8},
# 'op_names': ['conv1', 'conv2'],
# 'quant_dtype': 'int',
# 'quant_scheme': 'per_channel_symmetric'
# }]
# For now quant_dtype's options are 'int' and 'uint. And quant_scheme's options are per_tensor_affine,
# per_tensor_symmetric, per_channel_affine and per_channel_symmetric.
set_quant_scheme_dtype
(
'weight'
,
'per_channel_symmetric'
,
'int'
)
set_quant_scheme_dtype
(
'output'
,
'per_tensor_symmetric'
,
'int'
)
set_quant_scheme_dtype
(
'input'
,
'per_tensor_symmetric'
,
'int'
)
model
=
NaiveModel
().
to
(
device
)
model
=
NaiveModel
().
to
(
device
)
dummy_input
=
torch
.
randn
(
1
,
1
,
28
,
28
).
to
(
device
)
dummy_input
=
torch
.
randn
(
1
,
1
,
28
,
28
).
to
(
device
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
...
@@ -98,5 +110,6 @@ def main():
...
@@ -98,5 +110,6 @@ def main():
calibration_config
=
quantizer
.
export_model
(
model_path
,
calibration_path
,
onnx_path
,
input_shape
,
device
)
calibration_config
=
quantizer
.
export_model
(
model_path
,
calibration_path
,
onnx_path
,
input_shape
,
device
)
print
(
"Generated calibration config is: "
,
calibration_config
)
print
(
"Generated calibration config is: "
,
calibration_config
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
c9cd53aa
This diff is collapsed.
Click to expand it.
nni/common/version.py
0 → 100644
View file @
c9cd53aa
import
logging
try
:
import
torch
TORCH_VERSION
=
tuple
(
int
(
x
)
for
x
in
torch
.
__version__
.
split
(
"."
)[:
2
])
except
Exception
:
logging
.
info
(
"PyTorch is not installed."
)
TORCH_VERSION
=
None
nni/compression/pytorch/compressor.py
View file @
c9cd53aa
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
copy
import
types
import
types
import
logging
import
logging
import
torch
import
torch
from
nni.common.graph_utils
import
build_module_graph
from
nni.common.graph_utils
import
build_module_graph
from
nni.compression.pytorch.quantization.literal
import
QuantType
,
BN_FOLD_OP
,
BN_FOLD_TAG
from
nni.compression.pytorch.quantization.observers
import
RecordingObserver
from
.
import
default_layers
from
.
import
default_layers
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -547,7 +550,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -547,7 +550,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
assert
len
(
inputs
)
==
1
,
"Quantization of input only supports ops with single input."
assert
len
(
inputs
)
==
1
,
"Quantization of input only supports ops with single input."
new_inp
=
self
.
quantizer
.
quant_grad
(
new_inp
=
self
.
quantizer
.
quant_grad
(
inputs
[
0
],
inputs
[
0
],
QuantType
.
QUANT_
INPUT
,
QuantType
.
INPUT
,
self
)
self
)
inputs
=
(
new_inp
,)
inputs
=
(
new_inp
,)
...
@@ -563,7 +566,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -563,7 +566,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
quantizer
.
quant_grad
(
self
.
quantizer
.
quant_grad
(
new_weight
,
new_weight
,
QuantType
.
QUANT_
WEIGHT
,
QuantType
.
WEIGHT
,
self
,
inputs
[
0
])
self
,
inputs
[
0
])
result
=
self
.
module
(
*
inputs
)
result
=
self
.
module
(
*
inputs
)
...
@@ -571,7 +574,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -571,7 +574,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
if
'output'
in
self
.
config
[
'quant_types'
]:
if
'output'
in
self
.
config
[
'quant_types'
]:
result
=
self
.
quantizer
.
quant_grad
(
result
=
self
.
quantizer
.
quant_grad
(
result
,
result
,
QuantType
.
QUANT_
OUTPUT
,
QuantType
.
OUTPUT
,
self
)
self
)
return
result
return
result
...
@@ -604,10 +607,13 @@ class Quantizer(Compressor):
...
@@ -604,10 +607,13 @@ class Quantizer(Compressor):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
dummy_input
=
None
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
dummy_input
=
None
):
if
isinstance
(
model
,
torch
.
nn
.
DataParallel
):
if
isinstance
(
model
,
torch
.
nn
.
DataParallel
):
model
=
model
.
module
model
=
model
.
module
model_copied
=
copy
.
deepcopy
(
model
)
self
.
identity_wrappers
=
[]
self
.
identity_wrappers
=
[]
self
.
conv_bn_patterns
=
{}
self
.
conv_bn_patterns
=
{}
self
.
find_conv_bn_patterns
(
model
,
dummy_input
)
self
.
find_conv_bn_patterns
(
model
,
dummy_input
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
all_shapes
=
{}
self
.
record_shape
(
model_copied
,
dummy_input
)
self
.
quant_grad
=
QuantGrad
.
apply
self
.
quant_grad
=
QuantGrad
.
apply
if
self
.
optimizer
is
not
None
:
if
self
.
optimizer
is
not
None
:
self
.
patch_optimizer
(
self
.
step_with_optimizer
)
self
.
patch_optimizer
(
self
.
step_with_optimizer
)
...
@@ -845,25 +851,54 @@ class Quantizer(Compressor):
...
@@ -845,25 +851,54 @@ class Quantizer(Compressor):
if
successor
.
op_type
==
'BatchNorm2d'
:
if
successor
.
op_type
==
'BatchNorm2d'
:
self
.
conv_bn_patterns
[
node_group
.
name
]
=
successor
.
name
self
.
conv_bn_patterns
[
node_group
.
name
]
=
successor
.
name
def
step_with_optimizer
(
self
):
def
record_shape
(
self
,
model
,
dummy_input
):
pass
class
QuantType
:
"""
"""
Enum class for quantization type.
Record input/output's shapes of each module to be quantized
Parameters
----------
model : torch.nn.Module
model to be recorded.
dummy_input : tupel of torch.tensor
inputs to the model.
"""
"""
QUANT_INPUT
=
0
def
_pre_forward_hook
(
self
,
inp
):
QUANT_WEIGHT
=
1
# Only record the first tensor of the input
QUANT_OUTPUT
=
2
return
self
.
pre_forward
(
inp
[
0
])
def
_post_forward_hook
(
self
,
_
,
out
):
return
self
.
post_forward
(
out
)
QType_Dict
=
{
if
dummy_input
is
None
:
0
:
"input"
,
return
1
:
"weight"
,
2
:
"output"
all_handles
=
[]
}
all_observers
=
{}
modules_to_compress
=
self
.
get_modules_to_compress
()
compress_names
=
[
layer_info
[
0
].
name
for
layer_info
in
modules_to_compress
]
for
name
,
module
in
model
.
named_modules
():
if
name
in
compress_names
:
all_observers
[
name
]
=
{}
all_observers
[
name
][
'input_hook'
]
=
RecordingObserver
()
all_observers
[
name
][
'output_hook'
]
=
RecordingObserver
()
module
.
add_module
(
'pre_forward'
,
all_observers
[
name
][
'input_hook'
])
module
.
add_module
(
'post_forward'
,
all_observers
[
name
][
'output_hook'
])
all_handles
.
append
(
module
.
register_forward_pre_hook
(
_pre_forward_hook
))
all_handles
.
append
(
module
.
register_forward_hook
(
_post_forward_hook
))
model
(
dummy_input
)
for
name
,
hooks
in
all_observers
.
items
():
# only support single input
input_val
=
hooks
[
'input_hook'
].
tensor_val
input_shape
=
input_val
[
0
].
shape
if
input_val
else
None
output_val
=
hooks
[
'output_hook'
].
tensor_val
output_shape
=
output_val
[
0
].
shape
if
output_val
else
None
shapes
=
[
input_shape
,
output_shape
]
self
.
all_shapes
[
name
]
=
shapes
return
def
step_with_optimizer
(
self
):
pass
BN_FOLD_OP
=
[
"Conv2d"
]
BN_FOLD_TAG
=
'BN_FOLD_TAG'
class
QuantGrad
(
torch
.
autograd
.
Function
):
class
QuantGrad
(
torch
.
autograd
.
Function
):
"""
"""
...
@@ -920,8 +955,8 @@ class QuantGrad(torch.autograd.Function):
...
@@ -920,8 +955,8 @@ class QuantGrad(torch.autograd.Function):
grad_output : Tensor
grad_output : Tensor
gradient of the output of quantization operation
gradient of the output of quantization operation
scale : Tensor
scale : Tensor
the type of quantization, it can be `QuantType.
QUANT_
INPUT`, `QuantType.
QUANT_
WEIGHT`,
the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`,
`QuantType.
QUANT_
OUTPUT`, you can define different behavior for different types.
`QuantType.OUTPUT`, you can define different behavior for different types.
zero_point : Tensor
zero_point : Tensor
zero_point for quantizing tensor
zero_point for quantizing tensor
qmin : Tensor
qmin : Tensor
...
@@ -939,28 +974,39 @@ class QuantGrad(torch.autograd.Function):
...
@@ -939,28 +974,39 @@ class QuantGrad(torch.autograd.Function):
def
forward
(
ctx
,
tensor
,
quant_type
,
wrapper
,
input_tensor
=
None
,
**
kwargs
):
def
forward
(
ctx
,
tensor
,
quant_type
,
wrapper
,
input_tensor
=
None
,
**
kwargs
):
output
=
quantize_helper
(
tensor
,
quant_type
,
wrapper
,
input_tensor
,
**
kwargs
)
output
=
quantize_helper
(
tensor
,
quant_type
,
wrapper
,
input_tensor
,
**
kwargs
)
bits
=
QuantGrad
.
get_bits_length
(
wrapper
.
config
,
QType_Dict
[
quant_type
])
if
hasattr
(
wrapper
.
module
,
"layer_quant_setting"
):
qmin
,
qmax
=
torch
.
Tensor
([
0
]).
to
(
tensor
.
device
),
torch
.
Tensor
([(
1
<<
bits
)
-
1
]).
to
(
tensor
.
device
)
layer_quant_setting
=
wrapper
.
module
.
layer_quant_setting
if
hasattr
(
wrapper
.
module
,
'scale'
)
and
hasattr
(
wrapper
.
module
,
'zero_point'
):
qmin
,
qmax
=
getattr
(
layer_quant_setting
,
quant_type
).
get_qmin_qmax
()
else
:
# todo: when dtype/scheme customization is ready for all quantizers, remove this
bits
=
QuantGrad
.
get_bits_length
(
wrapper
.
config
,
quant_type
)
qmin
,
qmax
=
0
,
(
1
<<
bits
)
-
1
scale_name
,
zero_point_name
=
quant_type
.
type_to_scale_zero_point_name
()
if
hasattr
(
wrapper
.
module
,
scale_name
)
and
hasattr
(
wrapper
.
module
,
zero_point_name
):
scale
=
getattr
(
wrapper
.
module
,
scale_name
)
zero_point
=
getattr
(
wrapper
.
module
,
zero_point_name
)
# todo: remove this when other quantizers use different scale & zero point for input/weight/output
elif
hasattr
(
wrapper
.
module
,
'scale'
)
and
hasattr
(
wrapper
.
module
,
'zero_point'
):
scale
=
wrapper
.
module
.
scale
scale
=
wrapper
.
module
.
scale
zero_point
=
wrapper
.
module
.
zero_point
zero_point
=
wrapper
.
module
.
zero_point
else
:
else
:
scale
,
zero_point
=
None
,
None
scale
,
zero_point
=
None
,
None
ctx
.
save_for_backward
(
tensor
)
# Only tensors have gradients flowing back needs to be saved by save_for_backward.
# Only tensors have gradients flowing back needs to be saved by save_for_backward.
# Others should directly assign to ctx.
# Others should directly assign to ctx.
ctx
.
scale
=
scale
ctx
.
save_for_backward
(
tensor
)
ctx
.
zero_point
=
zero_point
ctx
.
quant_type
=
quant_type
ctx
.
quant_type
=
quant_type
ctx
.
qmin
,
ctx
.
qmax
=
qmin
,
qmax
ctx
.
qmin
,
ctx
.
qmax
=
qmin
,
qmax
ctx
.
scale
=
scale
ctx
.
zero_point
=
zero_point
return
output
return
output
@
classmethod
@
classmethod
def
backward
(
cls
,
ctx
,
grad_output
):
def
backward
(
cls
,
ctx
,
grad_output
):
tensor
=
ctx
.
saved_variables
[
0
]
tensor
=
ctx
.
saved_variables
[
0
]
scale
,
zero_point
=
ctx
.
scale
,
ctx
.
zero_point
scale
,
zero_point
=
ctx
.
scale
,
ctx
.
zero_point
qmin
,
qmax
=
ctx
.
qmin
,
ctx
.
qmax
quant_type
=
ctx
.
quant_type
quant_type
=
ctx
.
quant_type
qmin
,
qmax
=
ctx
.
qmin
,
ctx
.
qmax
output
=
cls
.
quant_backward
(
tensor
,
grad_output
,
quant_type
,
scale
,
zero_point
,
qmin
,
qmax
)
output
=
cls
.
quant_backward
(
tensor
,
grad_output
,
quant_type
,
scale
,
zero_point
,
qmin
,
qmax
)
return
output
,
None
,
None
,
None
return
output
,
None
,
None
,
None
...
@@ -977,11 +1023,11 @@ def _check_bias(module):
...
@@ -977,11 +1023,11 @@ def _check_bias(module):
return
False
return
False
def
quantize_helper
(
tensor
,
quant_type
,
wrapper
,
input_tensor
=
None
,
**
kwargs
):
def
quantize_helper
(
tensor
,
quant_type
,
wrapper
,
input_tensor
=
None
,
**
kwargs
):
if
quant_type
==
QuantType
.
QUANT_
INPUT
:
if
quant_type
==
QuantType
.
INPUT
:
output
=
wrapper
.
quantizer
.
quantize_input
(
tensor
,
wrapper
=
wrapper
,
**
kwargs
)
output
=
wrapper
.
quantizer
.
quantize_input
(
tensor
,
wrapper
=
wrapper
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_
WEIGHT
:
elif
quant_type
==
QuantType
.
WEIGHT
:
output
=
wrapper
.
quantizer
.
quantize_weight
(
wrapper
,
input_tensor
=
input_tensor
,
**
kwargs
)
output
=
wrapper
.
quantizer
.
quantize_weight
(
wrapper
,
input_tensor
=
input_tensor
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_
OUTPUT
:
elif
quant_type
==
QuantType
.
OUTPUT
:
output
=
wrapper
.
quantizer
.
quantize_output
(
tensor
,
wrapper
,
**
kwargs
)
output
=
wrapper
.
quantizer
.
quantize_output
(
tensor
,
wrapper
,
**
kwargs
)
else
:
else
:
raise
ValueError
(
"unrecognized QuantType."
)
raise
ValueError
(
"unrecognized QuantType."
)
...
...
nni/compression/pytorch/quantization/literal.py
0 → 100644
View file @
c9cd53aa
from
enum
import
Enum
,
EnumMeta
class
_QuantLiteralEnumMeta
(
EnumMeta
):
def
__contains__
(
cls
,
item
):
try
:
cls
(
item
)
except
ValueError
:
return
False
return
True
class
_QuantLiteralEnum
(
Enum
,
metaclass
=
_QuantLiteralEnumMeta
):
pass
class
QuantScheme
(
str
,
_QuantLiteralEnum
):
PER_TENSOR_AFFINE
=
'per_tensor_affine'
PER_TENSOR_SYMMETRIC
=
'per_tensor_symmetric'
PER_CHANNEL_AFFINE
=
'per_channel_affine'
PER_CHANNEL_SYMMETRIC
=
'per_channel_symmetric'
PER_CHANNEL_QUANT_SCHEME
=
[
QuantScheme
.
PER_CHANNEL_AFFINE
,
QuantScheme
.
PER_CHANNEL_SYMMETRIC
]
class
QuantDtype
(
str
,
_QuantLiteralEnum
):
UINT
=
'uint'
INT
=
'int'
class
QuantType
(
str
,
_QuantLiteralEnum
):
INPUT
=
'input'
WEIGHT
=
'weight'
OUTPUT
=
'output'
def
type_to_scale_zero_point_name
(
self
):
if
self
==
QuantType
.
INPUT
:
return
'input_scale'
,
'input_zero_point'
elif
self
==
QuantType
.
WEIGHT
:
return
'weight_scale'
,
'weight_zero_point'
elif
self
==
QuantType
.
OUTPUT
:
return
'output_scale'
,
'output_zero_point'
else
:
raise
TypeError
# Just show each attribute's name, no practical effect
class
QuantConfigLiteral
(
str
,
_QuantLiteralEnum
):
QUANT_SETTINGS
=
'quant_settings'
QUANT_SCHEME
=
'quant_scheme'
QUANT_DTYPE
=
'quant_dtype'
BITS
=
'bits'
QMIN
=
'qmin'
QMAX
=
'qmax'
INPUT_SCALE
=
'input_scale'
INPUT_ZERO_POINT
=
'input_zero_point'
OUTPUT_SCALE
=
'output_scale'
OUTPUT_ZERO_POINT
=
'output_zero_point'
WEIGHT_SCALE
=
'weight_scale'
WEIGHT_ZERO_POINT
=
'weight_zero_point'
BN_FOLD_OP
=
[
"Conv2d"
]
BN_FOLD_TAG
=
'BN_FOLD_TAG'
nni/
algorithms/
compression/pytorch/quantization/observers.py
→
nni/compression/pytorch/quantization/observers.py
View file @
c9cd53aa
from
torch.quantization
import
default_weight_observer
,
default_histogram_observer
from
torch.quantization
import
default_weight_observer
,
default_histogram_observer
from
torch.quantization
import
RecordingObserver
as
_RecordingObserver
__all__
=
[
"default_weight_observer"
,
"default_histogram_observer"
]
__all__
=
[
"default_weight_observer"
,
"default_histogram_observer"
,
"RecordingObserver"
]
class
RecordingObserver
(
_RecordingObserver
):
"""
A extended version of PyTorch's RecordingObserver, used to record gpu tensor
"""
def
forward
(
self
,
x
):
val
=
x
.
cpu
()
super
().
forward
(
val
)
return
x
nni/compression/pytorch/quantization/settings.py
0 → 100644
View file @
c9cd53aa
from
typing
import
Any
,
Optional
from
.literal
import
QuantDtype
,
QuantType
,
QuantScheme
from
.utils
import
calculate_qmin_qmax
,
get_bits_length
# default settings for quantization module
quant_default_settings
=
{
QuantType
.
WEIGHT
:
{
'quant_scheme'
:
QuantScheme
.
PER_TENSOR_AFFINE
,
'quant_dtype'
:
QuantDtype
.
UINT
,
},
QuantType
.
INPUT
:
{
'quant_scheme'
:
QuantScheme
.
PER_TENSOR_AFFINE
,
'quant_dtype'
:
QuantDtype
.
UINT
},
QuantType
.
OUTPUT
:
{
'quant_scheme'
:
QuantScheme
.
PER_TENSOR_AFFINE
,
'quant_dtype'
:
QuantDtype
.
UINT
}
}
class
TensorQuantSetting
(
object
):
def
__init__
(
self
,
**
kwargs
):
self
.
_fields
=
{}
for
k
,
v
in
kwargs
.
items
():
self
.
_fields
[
k
]
=
v
def
__setattr__
(
self
,
name
:
str
,
val
:
Any
)
->
None
:
if
name
.
startswith
(
"_"
):
super
().
__setattr__
(
name
,
val
)
else
:
self
.
_fields
[
name
]
=
val
def
__getattr__
(
self
,
name
):
if
name
==
"_fields"
or
name
not
in
self
.
_fields
:
raise
AttributeError
(
"Cannot find {} in TensorQuantSetting!"
.
format
(
name
))
return
self
.
_fields
[
name
]
def
get_qmin_qmax
(
self
):
assert
'qmin'
in
self
.
_fields
and
'qmax'
in
self
.
_fields
,
\
"Can not found qmin & qmax in TensorQuantSetting"
return
self
.
_fields
[
'qmin'
],
self
.
_fields
[
'qmax'
]
class
LayerQuantSetting
(
object
):
def
__init__
(
self
,
config
):
self
.
input
:
Optional
[
TensorQuantSetting
]
=
None
self
.
weight
:
Optional
[
TensorQuantSetting
]
=
None
self
.
output
:
Optional
[
TensorQuantSetting
]
=
None
self
.
_extra_layer_setting
=
{}
for
quant_type
in
QuantType
:
if
quant_type
in
config
.
get
(
"quant_types"
,
[]):
setting
=
TensorQuantSetting
()
quant_scheme
=
self
.
parse_optional_config
(
config
,
quant_type
,
'quant_scheme'
)
setting
.
quant_scheme
=
quant_scheme
quant_dtype
=
self
.
parse_optional_config
(
config
,
quant_type
,
'quant_dtype'
)
setting
.
quant_dtype
=
quant_dtype
bits
=
get_bits_length
(
config
,
quant_type
)
qmin
,
qmax
=
calculate_qmin_qmax
(
bits
,
quant_dtype
)
setting
.
bits
=
bits
setting
.
qmin
=
qmin
setting
.
qmax
=
qmax
setattr
(
self
,
quant_type
,
setting
)
def
__setattr__
(
self
,
name
:
str
,
val
:
Any
)
->
None
:
if
name
.
startswith
(
"_"
)
or
name
in
QuantType
:
super
().
__setattr__
(
name
,
val
)
else
:
self
.
_extra_layer_setting
[
name
]
=
val
def
__getattr__
(
self
,
name
):
if
name
==
"_extra_layer_setting"
or
name
not
in
self
.
_extra_layer_setting
:
raise
AttributeError
(
"Cannot find {} in LayerQuantSetting!"
.
format
(
name
))
return
self
.
_extra_layer_setting
[
name
]
@
staticmethod
def
parse_optional_config
(
config
,
quant_type
,
target
):
def
get_config
(
config
,
quant_type
,
target
):
if
not
config
.
get
(
target
):
return
None
if
isinstance
(
config
[
target
],
dict
):
return
config
[
target
].
get
(
quant_type
)
else
:
return
config
[
target
]
default_val
=
quant_default_settings
[
quant_type
].
get
(
target
,
None
)
config_val
=
get_config
(
config
,
quant_type
,
target
)
val
=
config_val
if
config_val
else
default_val
return
val
def
set_quant_scheme_dtype
(
quant_type
,
new_scheme
=
None
,
new_dtype
=
None
):
# todo: remove this if we convert string config to enum type.
if
isinstance
(
quant_type
,
str
):
assert
quant_type
in
QuantType
,
"Wrong quant_type"
if
isinstance
(
new_scheme
,
str
):
assert
new_scheme
in
QuantScheme
,
"Wrong quant_scheme"
if
isinstance
(
new_dtype
,
str
):
assert
new_dtype
in
QuantDtype
,
"Wrong quant_dtype"
# TODO: It is not a good idea to directly modify global settings. A better choice is
# making this function an attribute function of Quantizer and call this function after
# the quantizer is initialized. However, within current framework of quantization, if
# we want to modify the dtype & scheme when the quantizer is initialized, we must do
# some other things (like changing the shapes of scales and zero_points and other quantization
# information in the subclass).
global
quant_default_settings
if
new_scheme
is
not
None
:
quant_default_settings
[
quant_type
][
'quant_scheme'
]
=
new_scheme
if
new_dtype
is
not
None
:
quant_default_settings
[
quant_type
][
'quant_dtype'
]
=
new_dtype
return
nni/compression/pytorch/quantization/utils.py
0 → 100644
View file @
c9cd53aa
import
torch
from
nni.common.version
import
TORCH_VERSION
from
.literal
import
QuantDtype
,
QuantScheme
,
QuantType
def
calculate_qmin_qmax
(
bits
,
dtype
):
if
dtype
==
QuantDtype
.
INT
:
qmin
,
qmax
=
-
2
**
(
bits
-
1
)
+
1
,
2
**
(
bits
-
1
)
-
1
elif
dtype
==
QuantDtype
.
UINT
:
qmin
,
qmax
=
0
,
2
**
bits
-
1
else
:
raise
TypeError
(
"Wrong quantization dtype, please make sure it is one of 'int' and 'uint'."
)
return
qmin
,
qmax
def
get_bits_length
(
config
,
quant_type
):
if
isinstance
(
config
[
"quant_bits"
],
int
):
return
config
[
"quant_bits"
]
else
:
return
config
[
"quant_bits"
].
get
(
quant_type
)
def
get_target_dim
(
quant_type
,
quant_scheme
):
# for weight: c_out x c_in x (h) * (w)
# for feature maps: batch * channel * (t) * h * w
# other type is not supported for now
default_idx
=
0
if
quant_type
==
QuantType
.
WEIGHT
else
1
if
is_per_channel
(
quant_scheme
):
target_dim
=
default_idx
else
:
target_dim
=
None
return
target_dim
def
get_min_max_value
(
x
,
quant_type
,
quant_scheme
):
target_dim
=
get_target_dim
(
quant_type
,
quant_scheme
)
if
target_dim
is
None
:
return
torch
.
min
(
x
),
torch
.
max
(
x
)
indices
=
list
(
range
(
len
(
x
.
shape
)))
assert
target_dim
<
len
(
indices
),
"target_dim needs to be less than the number of dim of the tensor"
del
indices
[
target_dim
]
if
TORCH_VERSION
>
(
1
,
6
):
min_val
=
torch
.
amin
(
x
,
indices
,
keepdims
=
True
)
max_val
=
torch
.
amax
(
x
,
indices
,
keepdims
=
True
)
else
:
min_val
=
max_val
=
x
for
ind
in
indices
:
min_val
=
torch
.
min
(
min_val
,
dim
=
ind
,
keepdim
=
True
)[
0
]
max_val
=
torch
.
max
(
max_val
,
dim
=
ind
,
keepdim
=
True
)[
0
]
return
min_val
,
max_val
def
get_mean_value
(
x
,
target_dim
=
None
):
if
target_dim
is
None
:
return
torch
.
mean
(
x
)
indices
=
list
(
range
(
len
(
x
.
shape
)))
assert
target_dim
<
len
(
indices
),
"target_dim needs to be less than the number of dim of the tensor"
del
indices
[
target_dim
]
mean_val
=
torch
.
mean
(
x
,
dim
=
indices
,
keepdim
=
True
)
return
mean_val
def
is_per_channel
(
quant_scheme
):
if
quant_scheme
in
[
QuantScheme
.
PER_CHANNEL_AFFINE
,
QuantScheme
.
PER_CHANNEL_SYMMETRIC
]:
return
True
else
:
return
False
def
get_quant_shape
(
shape
,
quant_type
,
quant_scheme
):
default_idx
=
0
if
quant_type
==
QuantType
.
WEIGHT
else
1
if
is_per_channel
(
quant_scheme
):
quant_shape
=
[
1
if
idx
!=
default_idx
else
s
for
idx
,
s
in
enumerate
(
shape
)]
else
:
quant_shape
=
[]
return
quant_shape
test/ut/sdk/test_compressor_torch.py
View file @
c9cd53aa
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
import
schema
import
schema
import
nni.algorithms.compression.pytorch.pruning
as
torch_pruner
import
nni.algorithms.compression.pytorch.pruning
as
torch_pruner
import
nni.algorithms.compression.pytorch.quantization
as
torch_quantizer
import
nni.algorithms.compression.pytorch.quantization
as
torch_quantizer
from
nni.compression.pytorch.quantization.utils
import
calculate_qmin_qmax
,
get_quant_shape
,
get_min_max_value
import
math
import
math
...
@@ -50,7 +51,8 @@ class CompressorTestCase(TestCase):
...
@@ -50,7 +51,8 @@ class CompressorTestCase(TestCase):
model
.
relu
=
torch
.
nn
.
ReLU
()
model
.
relu
=
torch
.
nn
.
ReLU
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
quantizer
=
torch_quantizer
.
QAT_Quantizer
(
model
,
config_list
,
optimizer
)
dummy
=
torch
.
randn
(
1
,
1
,
28
,
28
)
quantizer
=
torch_quantizer
.
QAT_Quantizer
(
model
,
config_list
,
optimizer
,
dummy_input
=
dummy
)
quantizer
.
compress
()
quantizer
.
compress
()
modules_to_compress
=
quantizer
.
get_modules_to_compress
()
modules_to_compress
=
quantizer
.
get_modules_to_compress
()
modules_to_compress_name
=
[
t
[
0
].
name
for
t
in
modules_to_compress
]
modules_to_compress_name
=
[
t
[
0
].
name
for
t
in
modules_to_compress
]
...
@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase):
...
@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase):
self
.
assertFalse
(
isinstance
(
model
.
fc1
.
module
.
weight
,
torch
.
nn
.
Parameter
))
self
.
assertFalse
(
isinstance
(
model
.
fc1
.
module
.
weight
,
torch
.
nn
.
Parameter
))
self
.
assertFalse
(
isinstance
(
model
.
fc2
.
module
.
weight
,
torch
.
nn
.
Parameter
))
self
.
assertFalse
(
isinstance
(
model
.
fc2
.
module
.
weight
,
torch
.
nn
.
Parameter
))
def
test_quantization_dtype_scheme
(
self
):
class
TestModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
torch
.
nn
.
Conv2d
(
1
,
2
,
3
,
1
)
self
.
bn1
=
torch
.
nn
.
BatchNorm2d
(
2
)
def
forward
(
self
,
x
):
x
=
self
.
bn1
(
self
.
conv1
(
x
))
return
x
dtypes
=
[
'int'
,
'uint'
]
qschemes
=
[
'per_tensor_affine'
,
'per_tensor_symmetric'
,
'per_channel_affine'
,
'per_channel_symmetric'
]
for
dtype
in
dtypes
:
for
qscheme
in
qschemes
:
config_list
=
[{
'quant_types'
:
[
'weight'
,
'input'
],
'quant_bits'
:
8
,
'op_types'
:
[
'Conv2d'
],
'quant_dtype'
:
dtype
,
'quant_scheme'
:
qscheme
}]
model
=
TestModel
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
# only QAT_quantizer is supported for now
dummy
=
torch
.
randn
(
1
,
1
,
4
,
4
)
quantizer
=
torch_quantizer
.
QAT_Quantizer
(
model
,
config_list
,
optimizer
,
dummy_input
=
dummy
)
# test layer setting
for
layer
,
config
in
quantizer
.
modules_to_compress
:
module
=
layer
.
module
name
=
layer
.
name
layer_setting
=
module
.
layer_quant_setting
qmin
,
qmax
=
calculate_qmin_qmax
(
8
,
dtype
)
all_quant_types
=
[
'input'
,
'weight'
]
for
quant_type
in
all_quant_types
:
# check for settings
tensor_setting
=
getattr
(
layer_setting
,
quant_type
)
self
.
assertTrue
(
tensor_setting
is
not
None
)
self
.
assertTrue
(
tensor_setting
.
quant_scheme
==
qscheme
)
self
.
assertTrue
(
tensor_setting
.
quant_dtype
==
dtype
)
self
.
assertTrue
(
tensor_setting
.
qmin
==
qmin
)
self
.
assertTrue
(
tensor_setting
.
qmax
==
qmax
)
input_shape
,
output_shape
=
quantizer
.
all_shapes
[
name
]
shape
=
input_shape
if
quant_type
==
'input'
else
module
.
weight
.
shape
quant_shape
=
get_quant_shape
(
shape
,
quant_type
,
qscheme
)
scale_name
=
quant_type
+
'_scale'
zero_point_name
=
quant_type
+
'_zero_point'
scale
=
getattr
(
module
,
scale_name
)
zero_point
=
getattr
(
module
,
zero_point_name
)
self
.
assertTrue
(
list
(
scale
.
shape
)
==
quant_shape
)
self
.
assertTrue
(
list
(
zero_point
.
shape
)
==
quant_shape
)
weight
=
torch
.
arange
(
start
=
1
,
end
=
19
).
view
(
2
,
1
,
3
,
3
)
if
qscheme
==
'per_channel_symmetric'
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
([
9.
/
127
,
18.
/
127
]).
view
([
2
,
1
,
1
,
1
])
target_zero_point
=
torch
.
ones
([
2
,
1
,
1
,
1
])
*
0
else
:
target_scale
=
torch
.
tensor
([
9.
/
127.5
,
18.
/
127.5
]).
view
([
2
,
1
,
1
,
1
])
target_zero_point
=
torch
.
ones
([
2
,
1
,
1
,
1
])
*
127
elif
qscheme
==
'per_tensor_symmetric'
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
(
18.
/
127
)
target_zero_point
=
torch
.
zeros
([])
else
:
target_scale
=
torch
.
tensor
(
18.
/
127.5
)
target_zero_point
=
torch
.
ones
([])
*
127
elif
qscheme
==
'per_channel_affine'
:
min_val
=
torch
.
tensor
([
0.
,
0.
]).
view
([
2
,
1
,
1
,
1
])
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
([
9.
/
254
,
18.
/
254
]).
view
([
2
,
1
,
1
,
1
])
target_zero_point
=
-
127
-
torch
.
round
(
min_val
/
target_scale
)
else
:
target_scale
=
torch
.
tensor
([
9.
/
255
,
18.
/
255
]).
view
([
2
,
1
,
1
,
1
])
target_zero_point
=
0
-
torch
.
round
(
min_val
/
target_scale
)
else
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
(
18.
/
254
)
target_zero_point
=
-
127
-
torch
.
round
(
0
/
target_scale
)
else
:
target_scale
=
torch
.
tensor
(
18.
/
255
)
target_zero_point
=
0
-
torch
.
round
(
0
/
target_scale
)
wrapper
=
getattr
(
model
,
name
)
wrapper
.
module
.
weight
=
weight
quantizer
.
quantize_weight
(
wrapper
)
self
.
assertTrue
(
torch
.
equal
(
getattr
(
model
,
name
).
module
.
weight_scale
,
target_scale
))
self
.
assertTrue
(
torch
.
equal
(
getattr
(
model
,
name
).
module
.
weight_zero_point
,
target_zero_point
))
inp
=
torch
.
arange
(
start
=
0
,
end
=
16
).
view
(
1
,
1
,
4
,
4
)
if
qscheme
==
'per_channel_symmetric'
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
([
15.
/
127
]).
view
([
1
,
1
,
1
,
1
])
target_zero_point
=
torch
.
ones
([
1
,
1
,
1
,
1
])
*
0
else
:
target_scale
=
torch
.
tensor
([
15.
/
127.5
]).
view
([
1
,
1
,
1
,
1
])
target_zero_point
=
torch
.
ones
([
1
,
1
,
1
,
1
])
*
127
elif
qscheme
==
'per_tensor_symmetric'
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
(
15.
/
127
)
target_zero_point
=
torch
.
zeros
([])
else
:
target_scale
=
torch
.
tensor
(
15.
/
127.5
)
target_zero_point
=
torch
.
ones
([])
*
127
elif
qscheme
==
'per_channel_affine'
:
min_val
=
torch
.
tensor
([
0.
]).
view
([
1
,
1
,
1
,
1
])
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
([
15.
/
254
]).
view
([
1
,
1
,
1
,
1
])
target_zero_point
=
-
127
-
torch
.
round
(
min_val
/
target_scale
)
else
:
target_scale
=
torch
.
tensor
([
15.
/
255
]).
view
([
1
,
1
,
1
,
1
])
target_zero_point
=
0
-
torch
.
round
(
min_val
/
target_scale
)
else
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
(
15.
/
254
)
target_zero_point
=
-
127
-
torch
.
round
(
0
/
target_scale
)
else
:
target_scale
=
torch
.
tensor
(
15.
/
255
)
target_zero_point
=
0
-
torch
.
round
(
0
/
target_scale
)
quantizer
.
quantize_input
(
inp
,
wrapper
)
self
.
assertTrue
(
torch
.
equal
(
getattr
(
model
,
name
).
module
.
input_scale
,
target_scale
))
self
.
assertTrue
(
torch
.
equal
(
getattr
(
model
,
name
).
module
.
input_zero_point
,
target_zero_point
))
def
test_torch_QAT_quantizer
(
self
):
def
test_torch_QAT_quantizer
(
self
):
model
=
TorchModel
()
model
=
TorchModel
()
config_list
=
[{
config_list
=
[{
...
@@ -347,7 +473,8 @@ class CompressorTestCase(TestCase):
...
@@ -347,7 +473,8 @@ class CompressorTestCase(TestCase):
model
.
relu
=
torch
.
nn
.
ReLU
()
model
.
relu
=
torch
.
nn
.
ReLU
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
quantizer
=
torch_quantizer
.
QAT_Quantizer
(
model
,
config_list
,
optimizer
)
dummy
=
torch
.
randn
(
1
,
1
,
28
,
28
)
quantizer
=
torch_quantizer
.
QAT_Quantizer
(
model
,
config_list
,
optimizer
,
dummy_input
=
dummy
)
quantizer
.
compress
()
quantizer
.
compress
()
# test quantize
# test quantize
...
@@ -357,20 +484,20 @@ class CompressorTestCase(TestCase):
...
@@ -357,20 +484,20 @@ class CompressorTestCase(TestCase):
weight
=
torch
.
tensor
([[
1
,
2
],
[
3
,
5
]]).
float
()
weight
=
torch
.
tensor
([[
1
,
2
],
[
3
,
5
]]).
float
()
model
.
conv2
.
module
.
weight
.
data
=
weight
model
.
conv2
.
module
.
weight
.
data
=
weight
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
scale
,
5
/
255
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
weight_
scale
,
5
/
255
,
abs_tol
=
eps
)
assert
model
.
conv2
.
module
.
zero_point
==
0
assert
model
.
conv2
.
module
.
weight_
zero_point
==
0
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
scale
,
torch
.
tensor
([
0.0
4
/
255
])))
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
input_
scale
,
torch
.
tensor
([
4
.
/
255
])))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
zero_point
,
torch
.
tensor
(
[
0.
]
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
input_
zero_point
,
torch
.
tensor
(
0.
)))
# range including 0
# range including 0
weight
=
torch
.
tensor
([[
-
1
,
2
],
[
3
,
5
]]).
float
()
weight
=
torch
.
tensor
([[
-
1
,
2
],
[
3
,
5
]]).
float
()
model
.
conv2
.
module
.
weight
=
weight
model
.
conv2
.
module
.
weight
=
weight
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
scale
,
6
/
255
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
weight_
scale
,
6
/
255
,
abs_tol
=
eps
)
assert
model
.
conv2
.
module
.
zero_point
in
(
42
,
43
)
assert
model
.
conv2
.
module
.
weight_
zero_point
in
(
42
,
43
)
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
scale
,
torch
.
tensor
([
0.0796
/
255
])))
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
input_
scale
,
torch
.
tensor
([
4.
/
255
])))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
zero_point
,
torch
.
tensor
(
[
0.
]
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
input_
zero_point
,
torch
.
tensor
(
0.
)))
# test value of weight and bias after quantization
# test value of weight and bias after quantization
weight
=
torch
.
tensor
([[
1.1287
,
2.3456
],
[
3.7814
,
5.9723
]])
weight
=
torch
.
tensor
([[
1.1287
,
2.3456
],
[
3.7814
,
5.9723
]])
weight_valid
=
torch
.
tensor
([[
1.1242
,
2.3421
],
[
3.7707
,
5.9723
]])
weight_valid
=
torch
.
tensor
([[
1.1242
,
2.3421
],
[
3.7707
,
5.9723
]])
...
@@ -385,15 +512,15 @@ class CompressorTestCase(TestCase):
...
@@ -385,15 +512,15 @@ class CompressorTestCase(TestCase):
# test ema
# test ema
eps
=
1e-7
eps
=
1e-7
x
=
torch
.
tensor
([[
-
0.2
,
0
],
[
0.1
,
0.2
]])
x
=
torch
.
tensor
([[
-
0.2
,
0
],
[
0.1
,
0.2
]])
out
=
model
.
relu
(
x
)
model
.
relu
(
x
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_min_output
,
0
,
abs_tol
=
eps
)
self
.
assert
True
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_min_output
,
torch
.
tensor
(
0.
))
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_max_output
,
0.002
,
abs_tol
=
eps
)
self
.
assert
True
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_max_output
,
torch
.
tensor
(
0.2
))
)
quantizer
.
step_with_optimizer
()
quantizer
.
step_with_optimizer
()
x
=
torch
.
tensor
([[
0.2
,
0.4
],
[
0.6
,
0.8
]])
x
=
torch
.
tensor
([[
0.2
,
0.4
],
[
0.6
,
0.8
]])
out
=
model
.
relu
(
x
)
model
.
relu
(
x
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_min_output
,
0.002
,
abs_tol
=
eps
)
self
.
assert
True
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_min_output
,
torch
.
tensor
(
0.002
))
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_max_output
,
0.00998
,
abs_tol
=
eps
)
self
.
assert
True
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_max_output
,
torch
.
tensor
(
0.2060
))
)
def
test_torch_quantizer_export
(
self
):
def
test_torch_quantizer_export
(
self
):
config_list_qat
=
[{
config_list_qat
=
[{
...
@@ -424,11 +551,14 @@ class CompressorTestCase(TestCase):
...
@@ -424,11 +551,14 @@ class CompressorTestCase(TestCase):
}]
}]
config_set
=
[
config_list_qat
,
config_list_dorefa
,
config_list_bnn
]
config_set
=
[
config_list_qat
,
config_list_dorefa
,
config_list_bnn
]
quantize_algorithm_set
=
[
torch_quantizer
.
QAT_Quantizer
,
torch_quantizer
.
DoReFaQuantizer
,
torch_quantizer
.
BNNQuantizer
]
quantize_algorithm_set
=
[
torch_quantizer
.
QAT_Quantizer
,
torch_quantizer
.
DoReFaQuantizer
,
torch_quantizer
.
BNNQuantizer
]
dummy
=
torch
.
randn
(
1
,
1
,
28
,
28
)
for
config
,
quantize_algorithm
in
zip
(
config_set
,
quantize_algorithm_set
):
for
config
,
quantize_algorithm
in
zip
(
config_set
,
quantize_algorithm_set
):
model
=
TorchModel
()
model
=
TorchModel
()
model
.
relu
=
torch
.
nn
.
ReLU
()
model
.
relu
=
torch
.
nn
.
ReLU
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
if
quantize_algorithm
==
torch_quantizer
.
QAT_Quantizer
:
quantizer
=
quantize_algorithm
(
model
,
config
,
optimizer
,
dummy
)
else
:
quantizer
=
quantize_algorithm
(
model
,
config
,
optimizer
)
quantizer
=
quantize_algorithm
(
model
,
config
,
optimizer
)
quantizer
.
compress
()
quantizer
.
compress
()
...
@@ -461,6 +591,10 @@ class CompressorTestCase(TestCase):
...
@@ -461,6 +591,10 @@ class CompressorTestCase(TestCase):
model
=
TorchModel
().
eval
()
model
=
TorchModel
().
eval
()
model
.
relu
=
torch
.
nn
.
ReLU
()
model
.
relu
=
torch
.
nn
.
ReLU
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
if
quantize_algorithm
==
torch_quantizer
.
QAT_Quantizer
:
dummy
=
torch
.
randn
(
1
,
1
,
28
,
28
)
quantizer
=
quantize_algorithm
(
model
,
configure_list
,
optimizer
,
dummy_input
=
dummy
)
else
:
quantizer
=
quantize_algorithm
(
model
,
configure_list
,
optimizer
)
quantizer
=
quantize_algorithm
(
model
,
configure_list
,
optimizer
)
quantizer
.
compress
()
quantizer
.
compress
()
if
calibration_config
is
not
None
:
if
calibration_config
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment