OpenDAS / nni · Commit 3fc79c74 (unverified)

[Quantization] support load_calibration_config (#4163)

Authored Sep 15, 2021 by lin bin, committed by GitHub on Sep 15, 2021.
Parent: 9a9cb3d9
Showing 3 changed files, with 116 additions and 36 deletions (+116 −36):

    nni/algorithms/compression/pytorch/quantization/quantizers.py   +69 −36
    nni/compression/pytorch/compressor.py                            +18  −0
    test/ut/sdk/test_compressor_torch.py                             +29  −0
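This commit adds a load_calibration_config method to NNI's quantizers, so quantization parameters exported by one quantizer can seed another quantizer, or resume the same one. As a minimal usage sketch, assuming NNI 2.x's documented import path and mirroring the new unit test at the end of this diff (MyModel is an illustrative placeholder, not part of the commit):

    import torch
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    config_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv1', 'conv2']
    }]

    # First quantizer: compress, calibrate/train, then export weights plus
    # quantization parameters.
    model_a = MyModel()  # illustrative model with conv1/conv2 submodules
    quantizer_a = QAT_Quantizer(model_a, config_list,
                                torch.optim.SGD(model_a.parameters(), lr=0.01))
    quantizer_a.compress()
    calibration_config = quantizer_a.export_model("model.pth", "calibration.pth")

    # Second quantizer: start from the exported parameters instead of
    # cold-starting scales and tracked ranges.
    model_b = MyModel()
    quantizer_b = QAT_Quantizer(model_b, config_list,
                                torch.optim.SGD(model_b.parameters(), lr=0.01))
    quantizer_b.compress()
    quantizer_b.load_calibration_config(calibration_config)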
File: nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -388,15 +388,18 @@ class QAT_Quantizer(Quantizer):
             module.register_buffer("scale", torch.tensor([1.0]))
             module.register_buffer('ema_decay', torch.tensor([0.99]))
             if "weight" in config.get("quant_types", []):
-                module.register_buffer('weight_bits', torch.zeros(1))
+                weight_bits = get_bits_length(config, 'weight')
+                layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)]))
             if "input" in config.get("quant_types", []):
-                module.register_buffer('input_bits', torch.zeros(1))
-                module.register_buffer('tracked_min_input', torch.zeros(1))
-                module.register_buffer('tracked_max_input', torch.zeros(1))
+                input_bits = get_bits_length(config, 'input')
+                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
+                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
+                layer.module.register_buffer('input_bits', torch.Tensor([int(input_bits)]))
             if "output" in config.get("quant_types", []):
-                module.register_buffer('output_bits', torch.zeros(1))
-                module.register_buffer('tracked_min_output', torch.zeros(1))
-                module.register_buffer('tracked_max_output', torch.zeros(1))
+                output_bits = get_bits_length(config, 'output')
+                layer.module.register_buffer('output_bits', torch.Tensor([int(output_bits)]))
+                layer.module.register_buffer('tracked_min_output', torch.zeros(1))
+                layer.module.register_buffer('tracked_max_output', torch.zeros(1))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
@@ -484,7 +487,7 @@ class QAT_Quantizer(Quantizer):
         config = wrapper.config
         module = wrapper.module
         weight = module.weight
-        weight_bits = get_bits_length(config, 'weight')
+        weight_bits = int(module.weight_bits)
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
@@ -501,20 +504,13 @@ class QAT_Quantizer(Quantizer):
             module.zero_point.copy_(zero_point)
         weight = self._quantize(weight_bits, module, weight)
         weight = self._dequantize(module, weight)
-        module.weight_bits = torch.Tensor([weight_bits])
-        # Weight can not be in-place modified, so when use torch.nn.DataParallel, this update
-        # will be lost after each forward process. However, this update takes effect on each
-        # replicated module during each forward process, which will make the quantized weight
-        # be used correctly.
         wrapper.module.weight = weight
         return weight

     def quantize_input(self, inputs, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
-        input_bits = get_bits_length(config, 'input')
-        module.input_bit = torch.tensor([input_bits])
+        input_bits = int(module.input_bits)
         quant_start_step = config.get('quant_start_step', 0)
         assert input_bits >= 1, "quant bits length should be at least 1"
@@ -544,8 +540,7 @@ class QAT_Quantizer(Quantizer):
     def quantize_output(self, output, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
-        output_bits = get_bits_length(config, 'output')
-        module.output_bits = torch.Tensor([output_bits])
+        output_bits = int(module.output_bits)
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"
@@ -574,6 +569,25 @@ class QAT_Quantizer(Quantizer):
         out = self._dequantize(module, out)
         return out

+    def load_calibration_config(self, calibration_config):
+        modules_to_compress = self.get_modules_to_compress()
+        for layer, _ in modules_to_compress:
+            name, module = layer.name, layer.module
+            if name not in calibration_config:
+                if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits') or hasattr(module, 'input_bits'):
+                    logger.warning(f"Can not find module {name}'s parameter in input config.")
+                continue
+            if hasattr(module, 'weight_bits'):
+                assert calibration_config[name]['weight_bits'] == module.weight_bits, f"weight bits of module {name} fail to match"
+            if hasattr(module, 'input_bits'):
+                assert calibration_config[name]['input_bits'] == module.input_bits, f"input bits of module {name} fail to match"
+                module.tracked_min_input.data = torch.Tensor([calibration_config[name]['tracked_min_input']])
+                module.tracked_max_input.data = torch.Tensor([calibration_config[name]['tracked_max_input']])
+            if hasattr(module, 'output_bits'):
+                assert calibration_config[name]['output_bits'] == module.output_bits, f"output bits of module {name} fail to match"
+                module.tracked_min_output.data = torch.Tensor([calibration_config[name]['tracked_min_output']])
+                module.tracked_max_output.data = torch.Tensor([calibration_config[name]['tracked_max_output']])
+
     def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
         """
         Export quantized model weights and calibration parameters(optional)
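For reference, one entry of the calibration_config dict that this QAT load_calibration_config consumes (and that export_model emits) has roughly the shape below. The keys come straight from the diff; the numeric values are made-up examples:

    calibration_config = {
        'conv1': {
            'weight_bits': 8,
            'input_bits': 8,
            'tracked_min_input': -0.43,   # EMA-tracked activation range
            'tracked_max_input': 2.17,
            'output_bits': 8,
            'tracked_min_output': 0.0,    # e.g. post-ReLU outputs
            'tracked_max_output': 5.88,
        },
    }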
@@ -620,8 +634,8 @@ class QAT_Quantizer(Quantizer):
                     module.register_parameter('bias', actual_bias)
                 else:
                     setattr(module, 'bias', None)
-            if hasattr(module, 'input_bit'):
-                calibration_config[name]['input_bits'] = int(module.input_bit)
+            if hasattr(module, 'input_bits'):
+                calibration_config[name]['input_bits'] = int(module.input_bits)
                 calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
                 calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
@@ -655,7 +669,8 @@ class DoReFaQuantizer(Quantizer):
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
-                layer.module.register_buffer('weight_bits', torch.zeros(1))
+                weight_bits = get_bits_length(config, 'weight')
+                layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)]))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
@@ -690,13 +705,12 @@ class DoReFaQuantizer(Quantizer):
     def quantize_weight(self, wrapper, **kwargs):
         weight = wrapper.module.weight
-        weight_bits = get_bits_length(wrapper.config, 'weight')
+        weight_bits = int(wrapper.module.weight_bits)
         weight = weight.tanh()
         weight = weight / (2 * weight.abs().max()) + 0.5
         weight = self.quantize(weight, weight_bits)
         weight = 2 * weight - 1
         wrapper.module.weight = weight
-        wrapper.module.weight_bits = torch.Tensor([weight_bits])
         # wrapper.module.weight.data = weight
         return weight
@@ -764,7 +778,8 @@ class BNNQuantizer(Quantizer):
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
-                layer.module.register_buffer('weight_bits', torch.zeros(1))
+                weight_bits = get_bits_length(config, 'weight')
+                layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)]))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
@@ -890,10 +905,10 @@ class LsqQuantizer(Quantizer):
             if "weight" in config.get("quant_types", []):
                 layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0])))
                 # todo: support per-channel quantization for weight since TensorRT use it for conv weight
-                q_bits = get_bits_length(config, "weight")
-                layer.module.register_buffer('weight_bits', torch.Tensor([q_bits]))
-                qmax = 2 ** (q_bits - 1) - 1
-                qmin = -2 ** (q_bits - 1)
+                weight_bits = get_bits_length(config, "weight")
+                layer.module.register_buffer('weight_bits', torch.Tensor([weight_bits]))
+                qmax = 2 ** (weight_bits - 1) - 1
+                qmin = -2 ** (weight_bits - 1)
                 init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
                 layer.module.weight_scale = torch.nn.Parameter(init_weight_scale)
                 layer.module.weight_qmax = qmax
@@ -904,10 +919,10 @@ class LsqQuantizer(Quantizer):
             if "output" in config.get("quant_types", []):
                 # scale of output will be initialized using the first batch data
                 layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0])))
-                q_bits = get_bits_length(config, "output")
-                layer.module.register_buffer('output_bits', torch.Tensor([q_bits]))
-                qmax = 2 ** (q_bits - 1) - 1
-                qmin = -2 ** (q_bits - 1)
+                output_bits = get_bits_length(config, "output")
+                layer.module.register_buffer('output_bits', torch.Tensor([output_bits]))
+                qmax = 2 ** (output_bits - 1) - 1
+                qmin = -2 ** (output_bits - 1)
                 layer.module.output_qmax = qmax
                 layer.module.output_qmin = qmin
@@ -916,10 +931,10 @@ class LsqQuantizer(Quantizer):
             if "input" in config.get("quant_types", []):
                 # scale of input will be initialized using the first batch data
                 layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0])))
-                q_bits = get_bits_length(config, "input")
-                layer.module.register_buffer('input_bits', torch.Tensor([q_bits]))
-                qmax = 2 ** (q_bits - 1) - 1
-                qmin = -2 ** (q_bits - 1)
+                input_bits = get_bits_length(config, "input")
+                layer.module.register_buffer('input_bits', torch.Tensor([input_bits]))
+                qmax = 2 ** (input_bits - 1) - 1
+                qmin = -2 ** (input_bits - 1)
                 layer.module.input_qmax = qmax
                 layer.module.input_qmin = qmin
@@ -993,6 +1008,24 @@ class LsqQuantizer(Quantizer):
             inputs = self.quantize(inputs, module.input_scale, module.input_qmin, module.input_qmax)
         return inputs

+    def load_calibration_config(self, calibration_config):
+        modules_to_compress = self.get_modules_to_compress()
+        for layer, _ in modules_to_compress:
+            name, module = layer.name, layer.module
+            if name not in calibration_config:
+                if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits') or hasattr(module, 'input_bits'):
+                    logger.warning(f"Can not find module {name}'s parameter in input config.")
+                continue
+            if hasattr(module, 'weight_bits'):
+                assert calibration_config[name]['weight_bits'] == int(module.weight_bits), f"weight bits of module {name} fail to match"
+            if hasattr(module, 'input_bits'):
+                assert calibration_config[name]['input_bits'] == int(module.input_bits), f"input bits of module {name} fail to match"
+                module.input_scale.data = torch.Tensor([float(calibration_config[name]['tracked_max_input'] / module.input_qmax)])
+            if hasattr(module, 'output_bits'):
+                assert calibration_config[name]['output_bits'] == int(module.output_bits), f"output bits of module {name} fail to match"
+                module.output_scale.data = torch.Tensor([float(calibration_config[name]['tracked_max_output'] / module.output_qmax)])
+
     def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
         """
         Export quantized model weights and calibration parameters(optional)
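Note that, unlike the QAT version, the LSQ override above restores learned step sizes rather than min/max buffers, recovering each scale as tracked_max / qmax. A quick sanity check of that arithmetic, with illustrative numbers:

    input_bits = 8
    input_qmax = 2 ** (input_bits - 1) - 1   # 127 for signed 8-bit
    tracked_max_input = 2.54                 # value from an exported config
    input_scale = tracked_max_input / input_qmax
    print(input_scale)                       # 0.02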
File: nni/compression/pytorch/compressor.py

@@ -803,6 +803,24 @@ class Quantizer(Compressor):
         """
         raise NotImplementedError('Quantizer must overload export_model()')

+    def load_calibration_config(self, calibration_config):
+        """
+        This function aims to help quantizer set quantization parameters by
+        loading from a calibration_config which is exported by other quantizer
+        or itself. The main usage of this function is helping quantize aware training
+        quantizer set appropriate initial parameters so that the training process will
+        be much more flexible and converges quickly. What's more, it can also enable
+        quantizer resume quantization model by loading parameters from config.
+
+        Parameters
+        ----------
+        calibration_config : dict
+            dict which saves quantization parameters, quantizer can export itself
+            calibration config.
+            eg, calibration_config = quantizer.export_model(model_path, calibration_path)
+        """
+        raise NotImplementedError('Quantizer must overload export_model()')
+
     def find_conv_bn_patterns(self, model, dummy_input):
         """
         Find all Conv-BN patterns, used for batch normalization folding
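The docstring above also mentions resuming a quantized model from a config. Since export_model takes a calibration_path, a plausible resume flow is to reload that file and hand it back; reading the file with torch.load is an assumption based on the .pth convention, not something this diff shows:

    import torch

    # Assumes "test_calibration.pth" was written earlier by
    # quantizer.export_model(model_path, "test_calibration.pth").
    calibration_config = torch.load("test_calibration.pth")
    quantizer.load_calibration_config(calibration_config)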
File: test/ut/sdk/test_compressor_torch.py

@@ -445,6 +445,35 @@ class CompressorTestCase(TestCase):
         calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
         assert calibration_config is not None

+    def test_quantizer_load_calibration_config(self):
+        configure_list = [{
+            'quant_types': ['weight', 'input'],
+            'quant_bits': {'weight': 8, 'input': 8},
+            'op_names': ['conv1', 'conv2']
+        }, {
+            'quant_types': ['output', 'weight', 'input'],
+            'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
+            'op_names': ['fc1', 'fc2'],
+        }]
+        quantize_algorithm_set = [torch_quantizer.ObserverQuantizer, torch_quantizer.QAT_Quantizer, torch_quantizer.LsqQuantizer]
+        calibration_config = None
+        for quantize_algorithm in quantize_algorithm_set:
+            model = TorchModel().eval()
+            model.relu = torch.nn.ReLU()
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+            quantizer = quantize_algorithm(model, configure_list, optimizer)
+            quantizer.compress()
+
+            if calibration_config is not None:
+                quantizer.load_calibration_config(calibration_config)
+
+            model_path = "test_model.pth"
+            calibration_path = "test_calibration.pth"
+            onnx_path = "test_model.onnx"
+            input_shape = (1, 1, 28, 28)
+            device = torch.device("cpu")
+            calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
+
     def test_torch_pruner_validation(self):
         # test bad configuraiton
         pruner_classes = [torch_pruner.__dict__[x] for x in \
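The new test chains ObserverQuantizer, QAT_Quantizer, and LsqQuantizer in sequence, feeding each quantizer the calibration config exported by the previous one. To run just this test from a source checkout (assuming pytest is installed):

    import pytest

    # Select the single new test by its node id.
    pytest.main(["test/ut/sdk/test_compressor_torch.py::CompressorTestCase"
                 "::test_quantizer_load_calibration_config"])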