Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
af929fdb
Unverified
Commit
af929fdb
authored
May 18, 2021
by
chenbohua3
Committed by
GitHub
May 18, 2021
Browse files
Add LSQ quantizer (#3503)
parent
761732ab
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
435 additions
and
19 deletions
+435
-19
docs/en_US/Compression/Overview.rst
docs/en_US/Compression/Overview.rst
+2
-0
docs/en_US/Compression/Quantizer.rst
docs/en_US/Compression/Quantizer.rst
+56
-0
examples/model_compress/quantization/LSQ_torch_quantizer.py
examples/model_compress/quantization/LSQ_torch_quantizer.py
+142
-0
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+208
-6
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+27
-13
No files found.
docs/en_US/Compression/Overview.rst
View file @
af929fdb
...
...
@@ -87,6 +87,8 @@ Quantization algorithms compress the original network by reducing the number of
- DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. `Reference Paper <https://arxiv.org/abs/1606.06160>`__
* - `BNN Quantizer <../Compression/Quantizer.rst#bnn-quantizer>`__
- Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1. `Reference Paper <https://arxiv.org/abs/1602.02830>`__
* - `LSQ Quantizer <../Compression/Quantizer.rst#lsq-quantizer>`__
- Learned step size quantization. `Reference Paper <https://arxiv.org/pdf/1902.08153.pdf>`__
Model Speedup
...
...
docs/en_US/Compression/Quantizer.rst
View file @
af929fdb
...
...
@@ -8,6 +8,7 @@ Index of supported quantization algorithms
*
`
QAT
Quantizer
<#
qat
-
quantizer
>`
__
*
`
DoReFa
Quantizer
<#
dorefa
-
quantizer
>`
__
*
`
BNN
Quantizer
<#
bnn
-
quantizer
>`
__
*
`
LSQ
Quantizer
<#
lsq
-
quantizer
>`
__
Naive
Quantizer
---------------
...
...
@@ -86,6 +87,61 @@ note
batch
normalization
folding
is
currently
not
supported
.
----
LSQ
Quantizer
-------------
In
`
LEARNED
STEP
SIZE
QUANTIZATION
<
https
://
arxiv
.
org
/
pdf
/
1902.08153
.
pdf
>`
__
\
,
authors
Steven
K
.
Esser
and
Jeffrey
L
.
McKinstry
provide
an
algorithm
to
train
the
scales
with
gradients
.
..
The
authors
introduce
a
novel
means
to
estimate
and
scale
the
task
loss
gradient
at
each
weight
and
activation
layer
’
s
quantizer
step
size
,
such
that
it
can
be
learned
in
conjunction
with
other
network
parameters
.
Usage
^^^^^
Add the code below before your training code.
Three
things
must
be
done
:
1.
configure
which
layer
to
be
quantized
and
which
tensor
(
input
/
output
/
weight
)
of
that
layer
to
be
quantized
.
2.
construct
the
lsq
quantizer
3.
call
the
`
compress
`
API
PyTorch
code
..
code
-
block
::
python
from
nni
.
algorithms
.
compression
.
pytorch
.
quantization
import
LsqQuantizer
model
=
Mnist
()
configure_list
=
[{
'quant_types'
:
[
'weight'
,
'input'
],
'quant_bits'
:
{
'weight'
:
8
,
'input'
:
8
,
},
'op_names'
:
[
'conv1'
]
},
{
'quant_types'
:
[
'output'
],
'quant_bits'
:
{
'output'
:
8
,},
'op_names'
:
[
'relu1'
]
}]
quantizer
=
LsqQuantizer
(
model
,
configure_list
,
optimizer
)
quantizer
.
compress
()
See the example for more information.
:
githublink
:`
examples
/
model_compress
/
quantization
/
LSQ_torch_quantizer
.
py
<
examples
/
model_compress
/
quantization
/
LSQ_torch_quantizer
.
py
>`
User
configuration
for
LSQ
Quantizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Common configuration needed by compression algorithms can be found at `Specification of config_list <./QuickStart.rst>`__.
configuration
needed
by
this
algorithm
:
----
DoReFa
Quantizer
...
...
examples/model_compress/quantization/LSQ_torch_quantizer.py
0 → 100644
View file @
af929fdb
import
torch
import
torch.nn.functional
as
F
from
torchvision
import
datasets
,
transforms
from
nni.algorithms.compression.pytorch.quantization
import
LsqQuantizer
from
nni.compression.pytorch.quantization_speedup
import
ModelSpeedupTensorRT
class Mnist(torch.nn.Module):
    """Small convolutional MNIST classifier used by the LSQ quantization example.

    Two 5x5 convolutions, each followed by ReLU6 and 2x2 max pooling, feed two
    fully connected layers. ``forward`` returns per-class log-probabilities.

    The sub-module attribute names (``conv1``, ``relu1``, ``max_pool2``, ...)
    are referenced by name from the quantization config in ``main``, so they
    are part of this class's contract.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Layers holding parameters are created in a fixed order so that a
        # seeded run always draws the same initial weights.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)
        # Separate activation/pooling instances per stage so each one can be
        # targeted individually by op name.
        self.relu1 = nn.ReLU6()
        self.relu2 = nn.ReLU6()
        self.relu3 = nn.ReLU6()
        self.max_pool1 = nn.MaxPool2d(2, 2)
        self.max_pool2 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return log-softmax class scores for the image batch ``x``."""
        stage1 = self.max_pool1(self.relu1(self.conv1(x)))
        stage2 = self.max_pool2(self.relu2(self.conv2(stage1)))
        # Collapse the 50 x 4 x 4 feature maps into one vector per sample.
        flat = stage2.view(-1, 4 * 4 * 50)
        hidden = self.relu3(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
def train(model, quantizer, device, train_loader, optimizer):
    """Run one training epoch over ``train_loader``.

    Parameters
    ----------
    model : torch.nn.Module
        network to optimise; it is switched to training mode
    quantizer : object
        accepted for call-site symmetry; not used inside this function
    device : torch.device
        device each batch is moved to before the forward pass
    train_loader : iterable of (data, target)
        source of training batches; must support ``len()`` for progress output
    optimizer : torch.optim.Optimizer
        optimizer stepped once per batch
    """
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Report progress only on every 100th batch.
        if step % 100 != 0:
            continue
        percent_done = 100 * step / len(train_loader)
        print('{:2.0f}% Loss {}'.format(percent_done, loss.item()))
def
test
(
model
,
device
,
test_loader
):
model
.
eval
()
test_loss
=
0
correct
=
0
with
torch
.
no_grad
():
for
data
,
target
in
test_loader
:
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
output
=
model
(
data
)
test_loss
+=
F
.
nll_loss
(
output
,
target
,
reduction
=
'sum'
).
item
()
pred
=
output
.
argmax
(
dim
=
1
,
keepdim
=
True
)
correct
+=
pred
.
eq
(
target
.
view_as
(
pred
)).
sum
().
item
()
test_loss
/=
len
(
test_loader
.
dataset
)
print
(
'Loss: {} Accuracy: {}%)
\n
'
.
format
(
test_loss
,
100
*
correct
/
len
(
test_loader
.
dataset
)))
def
test_trt
(
engine
,
test_loader
):
test_loss
=
0
correct
=
0
time_elasped
=
0
for
data
,
target
in
test_loader
:
output
,
time
=
engine
.
inference
(
data
)
test_loss
+=
F
.
nll_loss
(
output
,
target
,
reduction
=
'sum'
).
item
()
pred
=
output
.
argmax
(
dim
=
1
,
keepdim
=
True
)
correct
+=
pred
.
eq
(
target
.
view_as
(
pred
)).
sum
().
item
()
time_elasped
+=
time
test_loss
/=
len
(
test_loader
.
dataset
)
print
(
'Loss: {} Accuracy: {}%'
.
format
(
test_loss
,
100
*
correct
/
len
(
test_loader
.
dataset
)))
print
(
"Inference elapsed_time (whole dataset): {}s"
.
format
(
time_elasped
))
def main():
    """Train an LSQ-quantized MNIST model, export it and benchmark it with TensorRT.

    Steps, in order: seed torch, build the MNIST data loaders, wrap the model
    with ``LsqQuantizer``, train for 40 epochs (evaluating after each), export
    the model and calibration config, then build a TensorRT engine from the
    exported config and evaluate it.
    """

    def _weight_and_input(op_name):
        # 8-bit weight and input quantization for a single named op.
        return {
            'quant_types': ['weight', 'input'],
            'quant_bits': {'weight': 8, 'input': 8},
            'op_names': [op_name],
        }

    def _output_only(op_name):
        # 8-bit output quantization for a single named op.
        return {
            'quant_types': ['output'],
            'quant_bits': {'output': 8},
            'op_names': [op_name],
        }

    torch.manual_seed(0)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = datasets.MNIST('data', train=True, download=True, transform=trans)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    test_set = datasets.MNIST('data', train=False, transform=trans)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000, shuffle=True)

    model = Mnist()
    configure_list = [
        _weight_and_input('conv1'),
        _output_only('relu1'),
        _weight_and_input('conv2'),
        _output_only('relu2'),
        _output_only('max_pool2'),
    ]

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = LsqQuantizer(model, configure_list, optimizer)
    quantizer.compress()

    model.to(device)
    for epoch in range(40):
        print('# Epoch {} #'.format(epoch))
        train(model, quantizer, device, train_loader, optimizer)
        test(model, device, test_loader)

    model_path = "mnist_model.pth"
    calibration_path = "mnist_calibration.pth"
    calibration_config = quantizer.export_model(model_path, calibration_path)

    # Evaluate once more after export.
    test(model, device, test_loader)

    print("calibration_config: ", calibration_config)

    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)

    engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
    engine.compress()

    test_trt(engine, test_loader)


if __name__ == '__main__':
    main()
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
af929fdb
...
...
@@ -6,9 +6,9 @@ import copy
import
torch
from
schema
import
Schema
,
And
,
Or
,
Optional
from
nni.compression.pytorch.utils.config_validation
import
CompressorSchema
from
nni.compression.pytorch.compressor
import
Quantizer
,
QuantGrad
,
QuantType
from
nni.compression.pytorch.compressor
import
Quantizer
,
QuantForward
,
QuantGrad
,
QuantType
__all__
=
[
'NaiveQuantizer'
,
'QAT_Quantizer'
,
'DoReFaQuantizer'
,
'BNNQuantizer'
]
__all__
=
[
'NaiveQuantizer'
,
'QAT_Quantizer'
,
'DoReFaQuantizer'
,
'BNNQuantizer'
,
'LsqQuantizer'
]
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -59,7 +59,7 @@ def update_ema(biased_ema, value, decay):
float, float
"""
biased_ema
=
biased_ema
*
decay
+
(
1
-
decay
)
*
value
return
biased_ema
return
biased_ema
def
update_quantization_param
(
bits
,
rmin
,
rmax
):
...
...
@@ -146,7 +146,7 @@ class QAT_Quantizer(Quantizer):
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
quant_grad
=
QATGrad
self
.
quant_grad
=
QATGrad
.
apply
modules_to_compress
=
self
.
get_modules_to_compress
()
self
.
bound_model
.
register_buffer
(
"steps"
,
torch
.
Tensor
([
1
]))
for
layer
,
config
in
modules_to_compress
:
...
...
@@ -474,7 +474,7 @@ class BNNQuantizer(Quantizer):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
quant_grad
=
ClipGrad
self
.
quant_grad
=
ClipGrad
.
apply
modules_to_compress
=
self
.
get_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
if
"weight"
in
config
.
get
(
"quant_types"
,
[]):
...
...
@@ -559,4 +559,206 @@ class BNNQuantizer(Quantizer):
self
.
export_model_save
(
self
.
bound_model
,
model_path
,
calibration_config
,
calibration_path
,
onnx_path
,
input_shape
,
device
)
return
calibration_config
\ No newline at end of file
return
calibration_config
class LsqQuantizer(Quantizer):
    """Quantizer defined in:
    Learned Step Size Quantization (ICLR 2020)
    https://arxiv.org/pdf/1902.08153.pdf

    Each quantized tensor (weight, input, output) gets a learnable scalar
    step size (``<type>_scale``) registered as a parameter on the wrapped
    module and, when an optimizer is supplied, added to that optimizer.
    """

    def __init__(self, model, config_list, optimizer=None):
        """
        Parameters
        ----------
        model : torch.nn.Module
            the model to be quantized
        config_list : list of dict
            list of configurations for quantization
            supported keys for dict:
                - quant_types : list of string
                    type of quantization you want to apply, currently support 'weight', 'input', 'output'
                - quant_bits : int or dict of {str : int}
                    bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8},
                    when the type is int, all quantization types share same bits length
                - op_types : list of string
                    types of nn.module you want to apply quantization, eg. 'Conv2d'
            NOTE(review): ``quant_start_step`` used to be listed here, but this class never reads it.
        optimizer : torch.optim.Optimizer, optional
            optimizer that should learn the step sizes. When it is None the
            scale parameters are still created but are not added to any
            optimizer (previously this raised AttributeError). The step
            counter is only advanced through ``step_with_optimizer``, so
            without an optimizer driving it the activation scales are
            re-initialised from every batch.
        """
        super().__init__(model, config_list, optimizer)
        self.quant_grad = QuantForward()
        modules_to_compress = self.get_modules_to_compress()
        self.bound_model.register_buffer("steps", torch.Tensor([1]))
        for layer, config in modules_to_compress:
            module = layer.module
            quant_types = config.get("quant_types", [])

            if "weight" in quant_types:
                # todo: support per-channel quantization for weight since TensorRT use it for conv weight
                qmax = self._setup_quant_range(module, config, "weight")
                # Weight step size is initialised from the weight statistics.
                init_weight_scale = module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
                module.weight_scale = torch.nn.Parameter(init_weight_scale)
                self._add_scale_to_optimizer(module.weight_scale)

            # Activation scales start at 1.0 and are initialised from the
            # first batch in quantize_output / quantize_input.
            for quant_type in ("output", "input"):
                if quant_type not in quant_types:
                    continue
                scale_name = quant_type + "_scale"
                module.register_parameter(scale_name, torch.nn.Parameter(torch.Tensor([1.0])))
                self._setup_quant_range(module, config, quant_type)
                self._add_scale_to_optimizer(getattr(module, scale_name))

    @staticmethod
    def _setup_quant_range(module, config, quant_type):
        """Store bit width and signed integer range for ``quant_type`` on ``module``; return qmax."""
        q_bit = get_bits_length(config, quant_type)
        module.register_buffer(quant_type + '_bit', torch.Tensor([q_bit]))
        qmax = 2 ** (q_bit - 1) - 1
        qmin = -2 ** (q_bit - 1)
        setattr(module, quant_type + '_qmax', qmax)
        setattr(module, quant_type + '_qmin', qmin)
        return qmax

    def _add_scale_to_optimizer(self, scale):
        """Add ``scale`` as a new parameter group when an optimizer was supplied."""
        if self.optimizer is not None:
            self.optimizer.add_param_group({"params": scale})

    @staticmethod
    def grad_scale(x, scale):
        r"""
        Used to scale the gradient. Give tensor `x`, we have `y=grad_scale(x, scale)=x` in the forward pass,
        which means that this function will not change the value of `x`. In the backward pass, we have:

        :math:`\frac{\partial L}{\partial x}=\frac{\partial L}{\partial y}*\frac{\partial y}{\partial x}=scale*\frac{\partial L}{\partial y}`

        This means that the origin gradient of x is scaled by a factor of `scale`. Applying this function
        to a nn.Parameter will scale the gradient of it without changing its value.
        """
        y = x
        y_grad = x * scale
        # Forward value is y; only y_grad participates in autograd.
        return (y - y_grad).detach() + y_grad

    @staticmethod
    def round_pass(x):
        """
        A simple way to achieve STE operation: the forward value is
        ``x.round()`` while the gradient flows through as if it were ``x``.
        """
        y = x.round()
        y_grad = x
        return (y - y_grad).detach() + y_grad

    def quantize(self, x, scale, qmin, qmax):
        """Fake-quantize ``x`` with step size ``scale`` into the integer range [qmin, qmax].

        The gradient reaching ``scale`` is multiplied by
        ``1 / sqrt(qmax * x.numel())``.
        """
        grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
        scale = self.grad_scale(scale, grad_scale_factor)
        x = x / scale
        x = torch.clamp(x, qmin, qmax)
        x = self.round_pass(x)
        x = x * scale
        return x

    def _init_scale_on_first_step(self, scale, tensor, qmax):
        """While the step counter is still 1, set ``scale`` from the statistics of ``tensor``."""
        if self.bound_model.steps == 1:
            scale.data = tensor.data.detach().abs().mean() * 2 / (qmax ** 0.5)

    def quantize_weight(self, wrapper, **kwargs):
        """Quantize ``module.old_weight`` and store the result as ``module.weight``."""
        module = wrapper.module

        # todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize
        # bias
        old_weight = module.old_weight
        weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
        module.weight = weight
        return weight

    def quantize_output(self, output, wrapper, **kwargs):
        """Quantize a layer output; the scale is initialised from the first batch."""
        module = wrapper.module

        # initialize the scale
        self._init_scale_on_first_step(module.output_scale, output, module.output_qmax)

        output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
        return output

    def quantize_input(self, *inputs, wrapper, **kwargs):
        """Quantize the first positional input of a layer and return the inputs as a tuple.

        Only ``inputs[0]`` is quantized; any further inputs are passed
        through unchanged.
        """
        # This is hacky since it is not recommended to modify a tuple
        # NB: support layers with multi inputs
        module = wrapper.module
        # initialize the scale
        self._init_scale_on_first_step(module.input_scale, inputs[0], module.input_qmax)

        new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
        list_inp = list(inputs)
        list_inp[0] = new_input
        return tuple(list_inp)

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters(optional)

        Parameters
        ----------
        model_path : str
            path to save quantized model weight
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if ```device``` is None

        Returns
        -------
        Dict
            per-module calibration config keyed by module name
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            has_weight = hasattr(module, 'weight_bit')
            has_input = hasattr(module, 'input_bit')
            has_output = hasattr(module, 'output_bit')
            # Create the entry for any quantized module. It used to be created
            # only for input/output quantization, so a weight-only module
            # raised KeyError below.
            if has_weight or has_input or has_output:
                calibration_config[name] = {}
            if has_weight:
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
                # Input range is only available when the input is quantized
                # too; reading input_scale unconditionally raised
                # AttributeError for modules without input quantization.
                if has_input:
                    abs_max_input = float(module.input_scale * module.input_qmax)
                    calibration_config[name]['tracked_min_input'] = -abs_max_input
                    calibration_config[name]['tracked_max_input'] = abs_max_input
            if has_output:
                calibration_config[name]['activation_bit'] = int(module.output_bit)
                abs_max_output = float(module.output_scale * module.output_qmax)
                calibration_config[name]['tracked_min_activation'] = -abs_max_output
                calibration_config[name]['tracked_max_activation'] = abs_max_output
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
                               input_shape, device)

        return calibration_config

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_activation',
                         'tracked_max_activation', 'output_scale', 'input_scale', 'weight_scale',
                         'weight_bit', 'output_bit', 'input_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def step_with_optimizer(self):
        """
        override `compressor` `step` method: advance the step counter, which
        ends the first-batch initialisation of the activation scales
        """
        self.bound_model.steps += 1
nni/compression/pytorch/compressor.py
View file @
af929fdb
...
...
@@ -474,13 +474,13 @@ class QuantizerModuleWrapper(torch.nn.Module):
def
forward
(
self
,
*
inputs
):
if
'input'
in
self
.
config
[
'quant_types'
]:
inputs
=
self
.
quantizer
.
quant_grad
.
apply
(
inputs
=
self
.
quantizer
.
quant_grad
(
inputs
,
QuantType
.
QUANT_INPUT
,
self
)
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
self
.
quantizer
.
quant_grad
.
apply
(
self
.
quantizer
.
quant_grad
(
self
.
module
.
old_weight
,
QuantType
.
QUANT_WEIGHT
,
self
,
inputs
[
0
])
...
...
@@ -489,12 +489,13 @@ class QuantizerModuleWrapper(torch.nn.Module):
result
=
self
.
module
(
*
inputs
)
if
'output'
in
self
.
config
[
'quant_types'
]:
result
=
self
.
quantizer
.
quant_grad
.
apply
(
result
=
self
.
quantizer
.
quant_grad
(
result
,
QuantType
.
QUANT_OUTPUT
,
self
)
return
result
class
Quantizer
(
Compressor
):
"""
Base quantizer for pytorch quantizer
...
...
@@ -502,7 +503,7 @@ class Quantizer(Compressor):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
quant_grad
=
QuantGrad
self
.
quant_grad
=
QuantGrad
.
apply
if
self
.
optimizer
is
not
None
:
self
.
patch_optimizer
(
self
.
step_with_optimizer
)
for
wrapper
in
self
.
get_modules_wrapper
():
...
...
@@ -719,15 +720,7 @@ class QuantGrad(torch.autograd.Function):
@
staticmethod
def
forward
(
ctx
,
tensor
,
quant_type
,
wrapper
,
input_tensor
=
None
,
**
kwargs
):
if
quant_type
==
QuantType
.
QUANT_INPUT
:
output
=
wrapper
.
quantizer
.
quantize_input
(
tensor
,
wrapper
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_WEIGHT
:
output
=
wrapper
.
quantizer
.
quantize_weight
(
wrapper
,
input_tensor
=
input_tensor
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_OUTPUT
:
output
=
wrapper
.
quantizer
.
quantize_output
(
tensor
,
wrapper
,
**
kwargs
)
else
:
raise
ValueError
(
"unrecognized QuantType."
)
output
=
quantize_helper
(
tensor
,
quant_type
,
wrapper
,
input_tensor
,
**
kwargs
)
bits
=
QuantGrad
.
get_bits_length
(
wrapper
.
config
,
QType_Dict
[
quant_type
])
qmin
,
qmax
=
torch
.
Tensor
([
0
]).
to
(
tensor
.
device
),
torch
.
Tensor
([(
1
<<
bits
)
-
1
]).
to
(
tensor
.
device
)
...
...
@@ -750,3 +743,24 @@ def _check_weight(module):
return
isinstance
(
module
.
weight
.
data
,
torch
.
Tensor
)
except
AttributeError
:
return
False
def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
    """Dispatch a quantization request to the matching method of ``wrapper.quantizer``.

    Parameters
    ----------
    tensor
        for ``QUANT_INPUT`` an iterable of inputs that is unpacked into
        ``quantize_input``; for ``QUANT_OUTPUT`` the output passed to
        ``quantize_output``; not forwarded for ``QUANT_WEIGHT``
    quant_type : QuantType
        which kind of tensor is being quantized
    wrapper
        module wrapper; must expose the quantizer as ``wrapper.quantizer``
    input_tensor
        forwarded to ``quantize_weight`` as a keyword argument
    **kwargs
        forwarded unchanged to the selected quantizer method

    Returns
    -------
    whatever the selected quantizer method returns

    Raises
    ------
    ValueError
        if ``quant_type`` is none of the three known types
    """
    if quant_type == QuantType.QUANT_INPUT:
        return wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
    if quant_type == QuantType.QUANT_WEIGHT:
        return wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
    if quant_type == QuantType.QUANT_OUTPUT:
        return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
    raise ValueError("unrecognized QuantType.")
class QuantForward(torch.nn.Module):
    """
    Callable module that runs a quantization operation through
    ``quantize_helper``. Intended for quantization algorithms that do not
    need a customized gradient.
    """

    def forward(self, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
        """Forward all arguments to ``quantize_helper`` and return its result."""
        quantized = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)
        return quantized
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment