OpenDAS / nni / Commits / 86335921

Commit 86335921 (unverified), authored Aug 16, 2021 by lin bin, committed by GitHub on Aug 16, 2021.

[Model Compression Quantization] Unify variable name (#3990)

Parent: e5c3ac63
Showing 6 changed files with 176 additions and 166 deletions.
examples/model_compress/quantization/mixed_precision_speedup_mnist.py   +7  -5
nni/algorithms/compression/pytorch/quantization/quantizers.py           +84 -80
nni/compression/pytorch/compressor.py                                    +1  -1
nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py         +10 -10
nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py      +63 -63
test/ut/sdk/test_compressor_torch.py                                     +11 -7
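
The renames in this commit follow one pattern: singular "bit"/"activation"-style names become plural "bits"/"output"-style names, in the quantizer buffers, the exported calibration dictionaries, and the TensorRT speedup config alike. As a quick orientation before the per-file hunks below, here is the mapping collected from the diff itself (the dict is only illustrative and is not part of the commit):

    # old key / attribute name          -> unified name introduced by this commit
    RENAMED = {
        'weight_bit': 'weight_bits',
        'activation_bit': 'output_bits',
        'input_bit': 'input_bits',
        'output_bit': 'output_bits',
        'tracked_min_activation': 'tracked_min_output',
        'tracked_max_activation': 'tracked_max_output',
        'tracked_weight_qmin': 'tracked_qmin_weight',
        'tracked_weight_qmax': 'tracked_qmax_weight',
        'tracked_input_qmin': 'tracked_qmin_input',
        'tracked_input_qmax': 'tracked_qmax_input',
        'tracked_activation_qmin': 'tracked_qmin_output',
        'tracked_activation_qmax': 'tracked_qmax_output',
        'extra_layer_bit': 'extra_layer_bits',   # build_engine / ModelSpeedupTensorRT parameter
    }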
examples/model_compress/quantization/mixed_precision_speedup_mnist.py

@@ -58,10 +58,10 @@ def post_training_quantization_example(train_loader, test_loader, device):
     model = NaiveModel()
     config = {
-        'conv1':{'weight_bit':8, 'activation_bit':8},
-        'conv2':{'weight_bit':32, 'activation_bit':32},
-        'fc1':{'weight_bit':16, 'activation_bit':16},
-        'fc2':{'weight_bit':8, 'activation_bit':8}
+        'conv1':{'weight_bits':8, 'output_bits':8},
+        'conv2':{'weight_bits':32, 'output_bits':32},
+        'fc1':{'weight_bits':16, 'output_bits':16},
+        'fc2':{'weight_bits':8, 'output_bits':8}
     }
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

@@ -102,8 +102,10 @@ def quantization_aware_training_example(train_loader, test_loader, device):
     ]
     # finetune the model by using QAT
+    # enable batchnorm folding mode
+    dummy_input = torch.randn(1, 1, 28, 28)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-    quantizer = QAT_Quantizer(model, configure_list, optimizer)
+    quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input=dummy_input)
     quantizer.compress()
     model.to(device)
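
The second hunk above is the user-visible API change: QAT_Quantizer now takes the optimizer as a required positional argument and accepts an optional dummy_input that enables batch-normalization folding. A minimal sketch of the call sequence, assuming the NaiveModel and the NNI import path used by this example script (both are assumptions, not shown in the hunks):

    import torch
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer  # assumed import path

    model = NaiveModel()                                   # MNIST-style model defined in the example script
    configure_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d', 'Linear']
    }]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    dummy_input = torch.randn(1, 1, 28, 28)                # providing it turns on batchnorm folding
    quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input=dummy_input)
    quantizer.compress()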
nni/algorithms/compression/pytorch/quantization/quantizers.py

@@ -124,7 +124,7 @@ class QATGrad(QuantGrad):
 class ObserverQuantizer(Quantizer):
-    """This quantizer uses observers to record weight/activation statistics to get quantization information.
+    """This quantizer uses observers to record weight/output statistics to get quantization information.
     The whole process can be divided into three steps:
     1. It will register observers to the place where quantization would happen (just like registering hooks).
     2. The observers would record tensors' statistics during calibration.

@@ -140,7 +140,7 @@ class ObserverQuantizer(Quantizer):
         # TODO:
         # 1. support dtype and qscheme customization through config_list. Current settings:
         #    weight observer : per_tensor_symmetric, qint8
-        #    activation observer : per_tensor_affine, quint8, reduce_range=True
+        #    output observer : per_tensor_affine, quint8, reduce_range=True
         # 2. add more kinds of observers, such as Kullback-Leibler divergence.
         # 3. add batch normalization folding
         assert not model.training, "Currently the observer quantizer only works in evaluation mode."

@@ -148,8 +148,8 @@ class ObserverQuantizer(Quantizer):
         self.device = next(model.parameters()).device
         modules_to_compress = self.get_modules_to_compress()
         all_observers = defaultdict(dict)
-        weight_q_min, weight_q_max = -127, 127
-        activation_q_min, activation_q_max = 0, 127  # reduce_range is set to True
+        weight_qmin, weight_qmax = -127, 127
+        output_qmin, output_qmax = 0, 127  # reduce_range is set to True
         self.compressed = False
         for layer, config in modules_to_compress:

@@ -157,16 +157,16 @@ class ObserverQuantizer(Quantizer):
             module = layer.module
             if "weight" in config.get("quant_types", []):
                 all_observers[layer_name]["weight"] = default_weight_observer()
-                setattr(module, "weight_qmax", weight_q_max)
-                setattr(module, "weight_qmin", weight_q_min)
+                setattr(module, "weight_qmax", weight_qmax)
+                setattr(module, "weight_qmin", weight_qmin)
             if "input" in config.get("quant_types", []):
                 all_observers[layer_name]["input"] = default_histogram_observer()
-                setattr(module, "input_qmax", activation_q_max)
-                setattr(module, "input_qmin", activation_q_min)
+                setattr(module, "input_qmax", output_qmax)
+                setattr(module, "input_qmin", output_qmin)
             if "output" in config.get("quant_types", []):
                 all_observers[layer_name]["output"] = default_histogram_observer()
-                setattr(module, "output_qmax", activation_q_max)
-                setattr(module, "output_qmin", activation_q_min)
+                setattr(module, "output_qmax", output_qmax)
+                setattr(module, "output_qmin", output_qmin)
         self.all_observers = all_observers
         self.bound_model.to(self.device)

@@ -306,29 +306,29 @@ class ObserverQuantizer(Quantizer):
             if hasattr(module, 'weight_scale') or hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
                 calibration_config[name] = {}
             if hasattr(module, 'weight_scale'):
-                calibration_config[name]['weight_bit'] = 8
+                calibration_config[name]['weight_bits'] = 8
                 val = float(module.weight_scale * module.weight_qmax)
                 calibration_config[name]['tracked_max_weight'] = val
                 calibration_config[name]['tracked_min_weight'] = -val
-                calibration_config[name]['tracked_weight_qmin'] = -127
-                calibration_config[name]['tracked_weight_qmax'] = 127
+                calibration_config[name]['tracked_qmin_weight'] = -127
+                calibration_config[name]['tracked_qmax_weight'] = 127
             # refactor these magic numbers when customizations of dtype and qscheme are ready.
             if hasattr(module, 'input_scale'):
-                calibration_config[name]['input_bit'] = 8
+                calibration_config[name]['input_bits'] = 8
                 max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point))
                 min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point))
                 calibration_config[name]['tracked_min_input'] = min_input
                 calibration_config[name]['tracked_max_input'] = max_input
-                calibration_config[name]['tracked_input_qmin'] = 0
-                calibration_config[name]['tracked_input_qmax'] = 127
+                calibration_config[name]['tracked_qmin_input'] = 0
+                calibration_config[name]['tracked_qmax_input'] = 127
             if hasattr(module, 'output_scale'):
-                calibration_config[name]['activation_bit'] = 8
+                calibration_config[name]['output_bits'] = 8
                 max_input = float(module.output_scale * (module.output_qmax - module.output_zero_point))
                 min_input = float(module.output_scale * (module.output_qmin - module.output_zero_point))
-                calibration_config[name]['tracked_min_activation'] = min_input
-                calibration_config[name]['tracked_max_activation'] = max_input
-                calibration_config[name]['tracked_activation_qmin'] = 0
-                calibration_config[name]['tracked_activation_qmax'] = 127
+                calibration_config[name]['tracked_min_output'] = min_input
+                calibration_config[name]['tracked_max_output'] = max_input
+                calibration_config[name]['tracked_qmin_output'] = 0
+                calibration_config[name]['tracked_qmax_output'] = 127
             self._del_simulated_attr(module)
         self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,

@@ -354,7 +354,7 @@ class QAT_Quantizer(Quantizer):
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
-    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
+    def __init__(self, model, config_list, optimizer, dummy_input=None):
         """
         Parameters
         ----------

@@ -370,7 +370,7 @@ class QAT_Quantizer(Quantizer):
                 when the type is int, all quantization types share same bits length
             - quant_start_step : int
                 disable quantization until model are run by certain number of steps, this allows the network to enter a more stable
-                state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
+                state where output quantization ranges do not exclude a significant fraction of values, default value is 0
             - op_types : list of string
                 types of nn.module you want to apply quantization, eg. 'Conv2d'
         - dummy_input : tuple of tensor

@@ -379,6 +379,7 @@ class QAT_Quantizer(Quantizer):
             given, the batch normalization folding would be disabled.
         """
+        assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
         super().__init__(model, config_list, optimizer, dummy_input)
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()

@@ -389,22 +390,22 @@ class QAT_Quantizer(Quantizer):
             layer.module.register_buffer("scale", torch.Tensor([1.0]))
             layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
             if "weight" in config.get("quant_types", []):
-                layer.module.register_buffer('weight_bit', torch.zeros(1))
+                layer.module.register_buffer('weight_bits', torch.zeros(1))
                 layer.module.register_buffer('tracked_min_input', torch.zeros(1))
                 layer.module.register_buffer('tracked_max_input', torch.zeros(1))
             if "output" in config.get("quant_types", []):
-                layer.module.register_buffer('activation_bit', torch.zeros(1))
-                layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_activation', torch.zeros(1))
+                layer.module.register_buffer('output_bits', torch.zeros(1))
+                layer.module.register_buffer('tracked_min_output', torch.zeros(1))
+                layer.module.register_buffer('tracked_max_output', torch.zeros(1))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
         """
         delete redundant parameters in quantize module
         """
-        del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
-                         'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit',
-                         'activation_bit', 'BN_FOLD_TAG']
+        del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
+                         'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bits',
+                         'output_bits', 'BN_FOLD_TAG']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)

@@ -506,7 +507,7 @@ class QAT_Quantizer(Quantizer):
         module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
         weight = self._quantize(weight_bits, module, weight)
         weight = self._dequantize(module, weight)
-        module.weight_bit = torch.Tensor([weight_bits])
+        module.weight_bits = torch.Tensor([weight_bits])
         wrapper.module.weight = weight
         return weight

@@ -514,23 +515,23 @@ class QAT_Quantizer(Quantizer):
         config = wrapper.config
         module = wrapper.module
         output_bits = get_bits_length(config, 'output')
-        module.activation_bit = torch.Tensor([output_bits])
+        module.output_bits = torch.Tensor([output_bits])
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"
         if quant_start_step > self.bound_model.steps:
-            module.tracked_min_activation, module.tracked_max_activation = torch.min(output), torch.max(output)
+            module.tracked_min_output, module.tracked_max_output = torch.min(output), torch.max(output)
             return output
         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_activation = update_ema(module.tracked_min_activation, current_min,
-                                                       module.ema_decay)
-            module.tracked_max_activation = update_ema(module.tracked_max_activation, current_max,
-                                                       module.ema_decay)
-            module.scale, module.zero_point = update_quantization_param(
-                output_bits, module.tracked_min_activation, module.tracked_max_activation)
+            module.tracked_min_output = update_ema(module.tracked_min_output, current_min,
+                                                   module.ema_decay)
+            module.tracked_max_output = update_ema(module.tracked_max_output, current_max,
+                                                   module.ema_decay)
+            module.scale, module.zero_point = update_quantization_param(
+                output_bits, module.tracked_min_output, module.tracked_max_output)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out

@@ -562,10 +563,10 @@ class QAT_Quantizer(Quantizer):
         calibration_config = {}
         for name, module in self.bound_model.named_modules():
-            if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
+            if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits'):
                 calibration_config[name] = {}
-            if hasattr(module, 'weight_bit'):
-                calibration_config[name]['weight_bit'] = int(module.weight_bit)
+            if hasattr(module, 'weight_bits'):
+                calibration_config[name]['weight_bits'] = int(module.weight_bits)
                 calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
                 calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)

@@ -585,10 +586,10 @@ class QAT_Quantizer(Quantizer):
                 else:
                     setattr(module, 'bias', None)
-            if hasattr(module, 'activation_bit'):
-                calibration_config[name]['activation_bit'] = int(module.activation_bit)
-                calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
-                calibration_config[name]['tracked_max_activation'] = float(module.tracked_max_activation)
+            if hasattr(module, 'output_bits'):
+                calibration_config[name]['output_bits'] = int(module.output_bits)
+                calibration_config[name]['tracked_min_output'] = float(module.tracked_min_output)
+                calibration_config[name]['tracked_max_output'] = float(module.tracked_max_output)
             self._del_simulated_attr(module)
         self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

@@ -642,20 +643,21 @@ class DoReFaQuantizer(Quantizer):
     (https://arxiv.org/abs/1606.06160)
     """
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
+        assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
         super().__init__(model, config_list, optimizer)
         device = next(model.parameters()).device
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
-                layer.module.register_buffer('weight_bit', torch.zeros(1))
+                layer.module.register_buffer('weight_bits', torch.zeros(1))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
         """
         delete redundant parameters in quantize module
         """
-        del_attr_list = ['old_weight', 'weight_bit']
+        del_attr_list = ['old_weight', 'weight_bits']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)

@@ -689,7 +691,7 @@ class DoReFaQuantizer(Quantizer):
         weight = self.quantize(weight, weight_bits)
         weight = 2 * weight - 1
         wrapper.module.weight = weight
-        wrapper.module.weight_bit = torch.Tensor([weight_bits])
+        wrapper.module.weight_bits = torch.Tensor([weight_bits])
         # wrapper.module.weight.data = weight
         return weight

@@ -725,9 +727,9 @@ class DoReFaQuantizer(Quantizer):
         calibration_config = {}
         for name, module in self.bound_model.named_modules():
-            if hasattr(module, 'weight_bit'):
+            if hasattr(module, 'weight_bits'):
                 calibration_config[name] = {}
-                calibration_config[name]['weight_bit'] = int(module.weight_bit)
+                calibration_config[name]['weight_bits'] = int(module.weight_bits)
             self._del_simulated_attr(module)
         self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

@@ -745,25 +747,26 @@ class ClipGrad(QuantGrad):
 class BNNQuantizer(Quantizer):
     """Binarized Neural Networks, as defined in:
-    Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
+    Binarized Neural Networks: Training Deep Neural Networks with Weights and Outputs Constrained to +1 or -1
     (https://arxiv.org/abs/1602.02830)
     """
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
+        assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
         super().__init__(model, config_list, optimizer)
         device = next(model.parameters()).device
         self.quant_grad = ClipGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
-                layer.module.register_buffer('weight_bit', torch.zeros(1))
+                layer.module.register_buffer('weight_bits', torch.zeros(1))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
         """
         delete redundant parameters in quantize module
         """
-        del_attr_list = ['old_weight', 'weight_bit']
+        del_attr_list = ['old_weight', 'weight_bits']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)

@@ -796,7 +799,7 @@ class BNNQuantizer(Quantizer):
         # remove zeros
         weight[weight == 0] = 1
         wrapper.module.weight = weight
-        wrapper.module.weight_bit = torch.Tensor([1.0])
+        wrapper.module.weight_bits = torch.Tensor([1.0])
         return weight

     def quantize_output(self, output, wrapper, **kwargs):

@@ -832,9 +835,9 @@ class BNNQuantizer(Quantizer):
         calibration_config = {}
         for name, module in self.bound_model.named_modules():
-            if hasattr(module, 'weight_bit'):
+            if hasattr(module, 'weight_bits'):
                 calibration_config[name] = {}
-                calibration_config[name]['weight_bit'] = int(module.weight_bit)
+                calibration_config[name]['weight_bits'] = int(module.weight_bits)
             self._del_simulated_attr(module)
         self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

@@ -848,7 +851,7 @@ class LsqQuantizer(Quantizer):
     https://arxiv.org/pdf/1902.08153.pdf
     """
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
         """
         Parameters
         ----------

@@ -864,10 +867,11 @@ class LsqQuantizer(Quantizer):
                 when the type is int, all quantization types share same bits length
             - quant_start_step : int
                 disable quantization until model are run by certain number of steps, this allows the network to enter a more stable
-                state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
+                state where output quantization ranges do not exclude a significant fraction of values, default value is 0
             - op_types : list of string
                 types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
+        assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
         super().__init__(model, config_list, optimizer)
         device = next(model.parameters()).device
         self.quant_grad = QuantForward()

@@ -877,10 +881,10 @@ class LsqQuantizer(Quantizer):
             if "weight" in config.get("quant_types", []):
                 layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0])))
                 # todo: support per-channel quantization for weight since TensorRT use it for conv weight
-                q_bit = get_bits_length(config, "weight")
-                layer.module.register_buffer('weight_bit', torch.Tensor([q_bit]))
-                qmax = 2 ** (q_bit - 1) - 1
-                qmin = -2 ** (q_bit - 1)
+                q_bits = get_bits_length(config, "weight")
+                layer.module.register_buffer('weight_bits', torch.Tensor([q_bits]))
+                qmax = 2 ** (q_bits - 1) - 1
+                qmin = -2 ** (q_bits - 1)
                 init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
                 layer.module.weight_scale = torch.nn.Parameter(init_weight_scale)
                 layer.module.weight_qmax = qmax

@@ -889,12 +893,12 @@ class LsqQuantizer(Quantizer):
                 self.optimizer.add_param_group({"params": layer.module.weight_scale})
             if "output" in config.get("quant_types", []):
-                # scale of activation will be initialized using the first batch data
+                # scale of output will be initialized using the first batch data
                 layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0])))
-                q_bit = get_bits_length(config, "output")
-                layer.module.register_buffer('output_bit', torch.Tensor([q_bit]))
-                qmax = 2 ** (q_bit - 1) - 1
-                qmin = -2 ** (q_bit - 1)
+                q_bits = get_bits_length(config, "output")
+                layer.module.register_buffer('output_bits', torch.Tensor([q_bits]))
+                qmax = 2 ** (q_bits - 1) - 1
+                qmin = -2 ** (q_bits - 1)
                 layer.module.output_qmax = qmax
                 layer.module.output_qmin = qmin

@@ -903,10 +907,10 @@ class LsqQuantizer(Quantizer):
             if "input" in config.get("quant_types", []):
                 # scale of input will be initialized using the first batch data
                 layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0])))
-                q_bit = get_bits_length(config, "input")
-                layer.module.register_buffer('input_bit', torch.Tensor([q_bit]))
-                qmax = 2 ** (q_bit - 1) - 1
-                qmin = -2 ** (q_bit - 1)
+                q_bits = get_bits_length(config, "input")
+                layer.module.register_buffer('input_bits', torch.Tensor([q_bits]))
+                qmax = 2 ** (q_bits - 1) - 1
+                qmin = -2 ** (q_bits - 1)
                 layer.module.input_qmax = qmax
                 layer.module.input_qmin = qmin

@@ -1011,18 +1015,18 @@ class LsqQuantizer(Quantizer):
         calibration_config = {}
         for name, module in self.bound_model.named_modules():
-            if hasattr(module, 'input_bit') or hasattr(module, 'output_bit'):
+            if hasattr(module, 'input_bits') or hasattr(module, 'output_bits'):
                 calibration_config[name] = {}
-            if hasattr(module, 'weight_bit'):
-                calibration_config[name]['weight_bit'] = int(module.weight_bit)
+            if hasattr(module, 'weight_bits'):
+                calibration_config[name]['weight_bits'] = int(module.weight_bits)
                 abs_max_input = float(module.input_scale * module.input_qmax)
                 calibration_config[name]['tracked_min_input'] = -abs_max_input
                 calibration_config[name]['tracked_max_input'] = abs_max_input
-            if hasattr(module, 'output_bit'):
-                calibration_config[name]['activation_bit'] = int(module.output_bit)
+            if hasattr(module, 'output_bits'):
+                calibration_config[name]['output_bits'] = int(module.output_bits)
                 abs_max_output = float(module.output_scale * module.output_qmax)
-                calibration_config[name]['tracked_min_activation'] = -abs_max_output
-                calibration_config[name]['tracked_max_activation'] = abs_max_output
+                calibration_config[name]['tracked_min_output'] = -abs_max_output
+                calibration_config[name]['tracked_max_output'] = abs_max_output
             self._del_simulated_attr(module)
         self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,

@@ -1034,8 +1038,8 @@ class LsqQuantizer(Quantizer):
         """
         delete redundant parameters in quantize module
         """
-        del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_activation', \
-            'tracked_max_activation', 'output_scale', 'input_scale', 'weight_scale', 'weight_bit', 'output_bit', 'input_bit']
+        del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_output', \
+            'tracked_max_output', 'output_scale', 'input_scale', 'weight_scale', 'weight_bits', 'output_bits', 'input_bits']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)
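
Downstream code reads these names back out of the calibration dictionary, so the export paths above now write the unified keys as well. A representative QAT entry, sketched from the keys in the hunks (the layer name and numeric values are placeholders, not taken from the commit):

    calibration_config = {
        'conv1': {
            'weight_bits': 8,
            'tracked_min_input': -0.42,    # placeholder value
            'tracked_max_input': 2.82,     # placeholder value
            'output_bits': 8,
            'tracked_min_output': 0.0,     # placeholder value
            'tracked_max_output': 6.73,    # placeholder value
        },
    }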
nni/compression/pytorch/compressor.py

@@ -834,7 +834,7 @@ class QuantGrad(torch.autograd.Function):
     @classmethod
     def get_bits_length(cls, config, quant_type):
         """
-        Get bit for quantize config
+        Get bits for quantize config
         Parameters
         ----------
         config : Dict
nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py

@@ -9,26 +9,26 @@ The main function of this page is to convert pytorch model to onnx model.
 Convertion from pytorch model to onnx model is primary so that a critical
 problem is caused that Layer name of pytorch model fail to convert to onnx
 layer name directly. To solve it, we wrap pytorch model in new wrapper which
-multiply bit number and input before computation of each op. Only in this
-way can onnx model get bit number of corresponded layer.
+multiply bits number and input before computation of each op. Only in this
+way can onnx model get bits number of corresponded layer.
 """

 class LayernameModuleWrapper(torch.nn.Module):
-    def __init__(self, module, module_bit) -> None:
+    def __init__(self, module, module_bits) -> None:
         """
         Parameters
         ----------
         module : torch.nn.Module
             Layer module of pytorch model
-        module_bit : int
-            Bit width setting for module
+        module_bits : int
+            Bits width setting for module
         """
         super().__init__()
         self.module = module
-        self.module_bit = module_bit
+        self.module_bits = module_bits

     def forward(self, inputs):
-        inputs = inputs*self.module_bit
+        inputs = inputs*self.module_bits
         inputs = self.module(inputs)
         return inputs

@@ -93,14 +93,14 @@ def unwrapper(model_onnx, index2name, config):
 def torch_to_onnx(model, config, input_shape, model_path, input_names, output_names):
     """
-    Convert torch model to onnx model and get layer bit config of onnx model.
+    Convert torch model to onnx model and get layer bits config of onnx model.
     Parameters
     ----------
     model : pytorch model
         The model to speed up by quantization
     config : dict
-        Config recording bit number and name of layers
+        Config recording bits number and name of layers
     input_shape : tuple
         The input shape of model, shall pass it to torch.onnx.export
     model_path : str

@@ -119,7 +119,7 @@ def torch_to_onnx(model, config, input_shape, model_path, input_names, output_na
     """
     # Support Gemm, Conv, Relu, Clip(Relu6) and MaxPool
     support_op = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.ReLU, torch.nn.ReLU6, torch.nn.MaxPool2d]
-    # Transfer bit number to onnx layer by using wrapper
+    # Transfer bits number to onnx layer by using wrapper
     index2name = {}
     name2index = {}
     if config is not None:
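
The wrapper above only carries each layer's bits setting into the exported ONNX graph; the entry point callers use is torch_to_onnx, whose signature appears in the hunk. A usage sketch, assuming model and calibration_config come from a quantizer export as in the earlier files; the import path and the return value are assumptions (the hunk only says the function "gets the layer bits config" of the ONNX model):

    from nni.compression.pytorch.quantization_speedup.frontend_to_onnx import torch_to_onnx  # assumed import path

    # calibration_config maps layer names to settings such as 'weight_bits' / 'output_bits'
    result = torch_to_onnx(model, calibration_config,
                           input_shape=(1, 1, 28, 28),        # forwarded to torch.onnx.export
                           model_path="quantized_model.onnx", # illustrative path
                           input_names=["actual_input_1"],
                           output_names=["output1"])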
nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py

@@ -31,18 +31,18 @@ Precision_Dict = {
 def valid_config(config=None):
     """
-    This function validates the bit setting configuration
+    This function validates the bits setting configuration
     """
     if config is None:
         return
-    support_bit = [8, 16, 32]
+    support_bits = [8, 16, 32]
     for name in config.keys():
-        if 'weight_bit' in config[name]:
-            w_bit = config[name]['weight_bit']
-            assert w_bit in support_bit, "weight bit should be 8, 16, 32"
-        if 'activation_bit' in config[name]:
-            a_bit = config[name]['activation_bit']
-            assert a_bit in support_bit, "activation bit should be 8, 16, 32"
+        if 'weight_bits' in config[name]:
+            w_bits = config[name]['weight_bits']
+            assert w_bits in support_bits, "weight bits should be 8, 16, 32"
+        if 'output_bits' in config[name]:
+            a_bits = config[name]['output_bits']
+            assert a_bits in support_bits, "output bits should be 8, 16, 32"

 def handle_gemm(network, layer_idx, config):
     """

@@ -55,26 +55,26 @@ def handle_gemm(network, layer_idx, config):
     layer_idx : int
         layer index of gemm
     config : dict
-        Config recording bit number and name of layers
+        Config recording bits number and name of layers
     """
     layer = network.get_layer(layer_idx)
     pre_layer = network.get_layer(layer_idx-1)
     next_layer = network.get_layer(layer_idx+1)
-    # if weight bit exists, set three layers' precision,
+    # if weight bits exists, set three layers' precision,
     # input tensor range and the first two layers' output type
-    if 'weight_bit' in config[layer.name]:
+    if 'weight_bits' in config[layer.name]:
         assert 'tracked_min_input' in config[layer.name]
         assert 'tracked_max_input' in config[layer.name]
-        w_bit = config[layer.name]['weight_bit']
+        w_bits = config[layer.name]['weight_bits']
         tracked_min_input = config[layer.name]['tracked_min_input']
         tracked_max_input = config[layer.name]['tracked_max_input']
         # set three layers the same precision
-        layer.precision = Precision_Dict[w_bit]
-        pre_layer.precision = Precision_Dict[w_bit]
-        next_layer.precision = Precision_Dict[w_bit]
+        layer.precision = Precision_Dict[w_bits]
+        pre_layer.precision = Precision_Dict[w_bits]
+        next_layer.precision = Precision_Dict[w_bits]
         # set the first two layers' output type
-        pre_layer.set_output_type(0, Precision_Dict[w_bit])
-        layer.set_output_type(0, Precision_Dict[w_bit])
+        pre_layer.set_output_type(0, Precision_Dict[w_bits])
+        layer.set_output_type(0, Precision_Dict[w_bits])
         pre_in_tensor = pre_layer.get_input(0)
         in_tensor = layer.get_input(0)
         next_in_tensor = next_layer.get_input(0)

@@ -83,20 +83,20 @@ def handle_gemm(network, layer_idx, config):
         in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
         next_in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
-    # if activation bit exists, set the last layer's output type output tensor range
-    if 'activation_bit' in config[layer.name]:
-        assert 'tracked_min_activation' in config[layer.name]
-        assert 'tracked_max_activation' in config[layer.name]
-        a_bit = config[layer.name]['activation_bit']
-        tracked_min_activation = config[layer.name]['tracked_min_activation']
-        tracked_max_activation = config[layer.name]['tracked_max_activation']
+    # if output bits exists, set the last layer's output type output tensor range
+    if 'output_bits' in config[layer.name]:
+        assert 'tracked_min_output' in config[layer.name]
+        assert 'tracked_max_output' in config[layer.name]
+        a_bits = config[layer.name]['output_bits']
+        tracked_min_output = config[layer.name]['tracked_min_output']
+        tracked_max_output = config[layer.name]['tracked_max_output']
         # set the last layer's output type
-        next_layer.set_output_type(0, Precision_Dict[a_bit])
+        next_layer.set_output_type(0, Precision_Dict[a_bits])
         next_out_tensor = next_layer.get_output(0)
         # set the last layer's output tensor range
-        next_out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)
+        next_out_tensor.dynamic_range = (tracked_min_output, tracked_max_output)

-def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
+def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=False, calib=None):
     """
     This function builds an engine from an onnx model with calibration process.

@@ -105,12 +105,12 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
     model_file : str
         The path of onnx model
     config : dict
-        Config recording bit number and name of layers
-    extra_layer_bit : int
-        Other layers which are not in config will be quantized to corresponding bit number
+        Config recording bits number and name of layers
+    extra_layer_bits : int
+        Other layers which are not in config will be quantized to corresponding bits number
     strict_datatype : bool
-        Whether constrain layer bit to the number given in config or not. If true, all the layer
-        will be set to given bit strictly. Otherwise, these layers will be set automatically by
+        Whether constrain layer bits to the number given in config or not. If true, all the layer
+        will be set to given bits strictly. Otherwise, these layers will be set automatically by
         tensorrt
     calib : numpy array
         The data using to calibrate quantization model

@@ -135,14 +135,14 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
     else:
         builder.max_workspace_size = common.GiB(4)
-    if extra_layer_bit == 32 and config is None:
+    if extra_layer_bits == 32 and config is None:
         pass
-    elif extra_layer_bit == 16 and config is None:
+    elif extra_layer_bits == 16 and config is None:
         if trt_version == TRT8:
             trt_config.set_flag(trt.BuilderFlag.FP16)
         else:
             builder.fp16_mode = True
-    elif extra_layer_bit == 8 and config is None:
+    elif extra_layer_bits == 8 and config is None:
         # entire model in 8bit mode
         if trt_version == TRT8:
             trt_config.set_flag(trt.BuilderFlag.INT8)

@@ -180,15 +180,15 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
                 break
             layer = network.get_layer(i)
             if layer.name in config:
-                w_bit = config[layer.name]['weight_bit']
-                a_bit = config[layer.name]['activation_bit']
-                layer.precision = Precision_Dict[w_bit]
-                layer.set_output_type(0, Precision_Dict[a_bit])
+                w_bits = config[layer.name]['weight_bits']
+                a_bits = config[layer.name]['output_bits']
+                layer.precision = Precision_Dict[w_bits]
+                layer.set_output_type(0, Precision_Dict[a_bits])
     else:
         # This implementation may be incorrect when output number > 1
         for i in range(network.num_layers):
             if config is None:
-                # no low bit layer need to be set, keep original model
+                # no low bits layer need to be set, keep original model
                 break
             layer = network.get_layer(i)
             if layer.name not in config:

@@ -198,37 +198,37 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
                 handle_gemm(network, i, config)
                 continue
-            # If weight_bit exists in config, set layer precision and layer's input tensor dynamic range.
-            if 'weight_bit' in config[layer.name]:
+            # If weight_bits exists in config, set layer precision and layer's input tensor dynamic range.
+            if 'weight_bits' in config[layer.name]:
                 assert 'tracked_min_input' in config[layer.name]
                 assert 'tracked_max_input' in config[layer.name]
-                w_bit = config[layer.name]['weight_bit']
+                w_bits = config[layer.name]['weight_bits']
                 tracked_min_input = config[layer.name]['tracked_min_input']
                 tracked_max_input = config[layer.name]['tracked_max_input']
-                layer.precision = Precision_Dict[w_bit]
+                layer.precision = Precision_Dict[w_bits]
                 in_tensor = layer.get_input(0)
                 in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
-            # If activation exists in config, set layer output type and layer's output tensor dynamic range.
-            if 'activation_bit' in config[layer.name]:
-                assert 'tracked_min_activation' in config[layer.name]
-                assert 'tracked_max_activation' in config[layer.name]
-                a_bit = config[layer.name]['activation_bit']
-                tracked_min_activation = config[layer.name]['tracked_min_activation']
-                tracked_max_activation = config[layer.name]['tracked_max_activation']
-                layer.set_output_type(0, Precision_Dict[a_bit])
+            # If output exists in config, set layer output type and layer's output tensor dynamic range.
+            if 'output_bits' in config[layer.name]:
+                assert 'tracked_min_output' in config[layer.name]
+                assert 'tracked_max_output' in config[layer.name]
+                a_bits = config[layer.name]['output_bits']
+                tracked_min_output = config[layer.name]['tracked_min_output']
+                tracked_max_output = config[layer.name]['tracked_max_output']
+                layer.set_output_type(0, Precision_Dict[a_bits])
                 out_tensor = layer.get_output(0)
-                out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)
+                out_tensor.dynamic_range = (tracked_min_output, tracked_max_output)
     # Build engine and do int8 calibration.
     if trt_version == TRT8:
         engine = builder.build_engine(network, trt_config)
     else:
-        engine.builder.build_cuda_engine(network)
+        engine = builder.build_cuda_engine(network)
     return engine

 class ModelSpeedupTensorRT(BaseModelSpeedup):
-    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bit=32, strict_datatype=True,
+    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
                  calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache="calibration.cache", batchsize=1,
                  input_names=["actual_input_1"], output_names=["output1"]):
         """

@@ -239,14 +239,14 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
         input_shape : tuple
             The input shape of model, shall pass it to torch.onnx.export.
         config : dict
-            Config recording bit number and name of layers.
+            Config recording bits number and name of layers.
         onnx_path : str
             The path user want to store onnx model which is converted from pytorch model.
-        extra_layer_bit : int
-            Other layers which are not in config will be quantized to corresponding bit number.
+        extra_layer_bits : int
+            Other layers which are not in config will be quantized to corresponding bits number.
         strict_datatype : bool
-            Whether constrain layer bit to the number given in config or not. If true, all the layer
-            will be set to given bit strictly. Otherwise, these layers will be set automatically by
+            Whether constrain layer bits to the number given in config or not. If true, all the layer
+            will be set to given bits strictly. Otherwise, these layers will be set automatically by
             tensorrt.
         calibrate_type : tensorrt.tensorrt.CalibrationAlgoType
             The algorithm of calibrating. Please refer to https://docs.nvidia.com/deeplearning/

@@ -267,7 +267,7 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
         self.onnx_path = onnx_path
         self.input_shape = input_shape
         self.config = config
-        self.extra_layer_bit = extra_layer_bit
+        self.extra_layer_bits = extra_layer_bits
         self.strict_datatype = strict_datatype
         self.calibrate_type = calibrate_type
         self.calib_data_loader = calib_data_loader

@@ -327,7 +327,7 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
         calib = calibrator.Calibrator(calib_data, self.calibration_cache, self.batchsize, self.calibrate_type)
         # build inference engine with calibration
-        engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bit, self.strict_datatype, calib)
+        engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bits, self.strict_datatype, calib)
         return engine.create_execution_context()

     def _tensorrt_build_withoutcalib(self, onnx_path):

@@ -344,7 +344,7 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
         tensorrt.IExecutionContext
             Context for executing inference using an ICudaEngine
         """
-        engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bit, self.strict_datatype)
+        engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bits, self.strict_datatype)
         return engine.create_execution_context()

     def inference(self, test_data):
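
From the user's side, the visible change in this file is the renamed argument extra_layer_bits on ModelSpeedupTensorRT and build_engine. A usage sketch, assuming calibration_config was exported with the unified keys; the import path is an assumption, compress() is not shown in these hunks, and the return value of inference() is likewise an assumption:

    from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT  # assumed import path

    engine = ModelSpeedupTensorRT(model, input_shape=(32, 1, 28, 28),
                                  config=calibration_config,    # keys: 'weight_bits', 'output_bits', ...
                                  extra_layer_bits=32,          # renamed from extra_layer_bit in this commit
                                  strict_datatype=True)
    engine.compress()                                           # builds the TensorRT engine (method not shown in this diff)
    result = engine.inference(test_data)                        # inference() is defined at the end of the hunks above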
test/ut/sdk/test_compressor_torch.py

@@ -49,7 +49,8 @@ class CompressorTestCase(TestCase):
         }]
         model.relu = torch.nn.ReLU()
-        quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
         quantizer.compress()
         modules_to_compress = quantizer.get_modules_to_compress()
         modules_to_compress_name = [t[0].name for t in modules_to_compress]

@@ -317,7 +318,9 @@ class CompressorTestCase(TestCase):
             'op_types': ['ReLU']
         }]
         model.relu = torch.nn.ReLU()
-        quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
         quantizer.compress()
         # test quantize

@@ -350,14 +353,14 @@ class CompressorTestCase(TestCase):
         eps = 1e-7
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
         out = model.relu(x)
-        assert math.isclose(model.relu.module.tracked_min_activation, 0, abs_tol=eps)
-        assert math.isclose(model.relu.module.tracked_max_activation, 0.002, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_min_output, 0, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_max_output, 0.002, abs_tol=eps)
         quantizer.step_with_optimizer()
         x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
         out = model.relu(x)
-        assert math.isclose(model.relu.module.tracked_min_activation, 0.002, abs_tol=eps)
-        assert math.isclose(model.relu.module.tracked_max_activation, 0.00998, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_min_output, 0.002, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_max_output, 0.00998, abs_tol=eps)

     def test_torch_quantizer_export(self):
         config_list_qat = [{

@@ -392,7 +395,8 @@ class CompressorTestCase(TestCase):
         for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
             model = TorchModel()
             model.relu = torch.nn.ReLU()
-            quantizer = quantize_algorithm(model, config)
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+            quantizer = quantize_algorithm(model, config, optimizer)
             quantizer.compress()
             x = torch.rand((1, 1, 28, 28), requires_grad=True)