Commit f0e3c584 in OpenDAS / nni (unverified), authored Apr 09, 2021 by lin bin, committed by GitHub on Apr 09, 2021.

Combine tensorrt tool with NNI quantization algorithms. (#3488)

Parent: 80bc9537

Showing 11 changed files with 977 additions and 29 deletions (+977, −29):
examples/model_compress/quantization/mixed_precision_speedup_mnist.py   +169 −0
nni/algorithms/compression/pytorch/quantization/quantizers.py           +28 −14
nni/compression/pytorch/compressor.py                                   +4 −6
nni/compression/pytorch/quantization_speedup/__init__.py                +1 −0
nni/compression/pytorch/quantization_speedup/backend.py                 +51 −0
nni/compression/pytorch/quantization_speedup/calibrator.py              +99 −0
nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py        +148 −0
nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py     +381 −0
nni/compression/pytorch/quantization_speedup/trt_pycuda.py              +86 −0
pylintrc                                                                 +2 −2
test/ut/sdk/test_compressor_torch.py                                     +8 −7
examples/model_compress/quantization/mixed_precision_speedup_mnist.py  (new file, 0 → 100644)
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT


class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
        self.max_pool2 = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.max_pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.max_pool2(x)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(test_loss, 100 * correct / len(test_loader.dataset)))


def test_trt(engine, test_loader):
    test_loss = 0
    correct = 0
    time_elapsed = 0
    for data, target in test_loader:
        output, time = engine.inference(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        time_elapsed += time
    test_loss /= len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%'.format(test_loss, 100 * correct / len(test_loader.dataset)))
    print("Inference elapsed_time (whole dataset): {}s".format(time_elapsed))


def post_training_quantization_example(train_loader, test_loader, device):
    model = Mnist()

    config = {
        'conv1': {'weight_bit': 8, 'activation_bit': 8},
        'conv2': {'weight_bit': 32, 'activation_bit': 32},
        'fc1': {'weight_bit': 16, 'activation_bit': 16},
        'fc2': {'weight_bit': 8, 'activation_bit': 8}
    }

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    model.to(device)
    for epoch in range(1):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)

    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)

    engine = ModelSpeedupTensorRT(model, input_shape, config=config, calib_data_loader=train_loader, batchsize=batch_size)
    engine.compress()

    test_trt(engine, test_loader)


def quantization_aware_training_example(train_loader, test_loader, device):
    model = Mnist()

    configure_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_names': ['conv1']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu1']
    }, {
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_names': ['conv2']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu2']
    }]

    # finetune the model by using QAT
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = QAT_Quantizer(model, configure_list, optimizer)
    quantizer.compress()

    model.to(device)
    for epoch in range(1):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)

    model_path = "mnist_model.pth"
    calibration_path = "mnist_calibration.pth"
    calibration_config = quantizer.export_model(model_path, calibration_path)

    test(model, device, test_loader)

    print("calibration_config: ", calibration_config)

    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)

    engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
    engine.compress()

    test_trt(engine, test_loader)


def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    # post-training quantization on TensorRT
    post_training_quantization_example(train_loader, test_loader, device)

    # combine NNI quantization algorithm QAT with backend framework TensorRT
    quantization_aware_training_example(train_loader, test_loader, device)


if __name__ == '__main__':
    main()
\ No newline at end of file
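The example keeps the built engine in memory. The ModelSpeedupTensorRT class added in this same commit also provides export_quantized_model and load_quantized_model, so the serialized engine can be reused without rebuilding; a rough sketch (the engine file name is illustrative):

# Persist the built engine and reload it later, reusing names from the example above.
engine.export_quantized_model('mnist_int8.trt')

engine2 = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
engine2.load_quantized_model('mnist_int8.trt')
test_trt(engine2, test_loader)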
nni/algorithms/compression/pytorch/quantization/quantizers.py
...
...
@@ -152,22 +152,23 @@ class QAT_Quantizer(Quantizer):
         for layer, config in modules_to_compress:
             layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
             layer.module.register_buffer("scale", torch.Tensor([1.0]))
+            layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
             if "weight" in config.get("quant_types", []):
                 layer.module.register_buffer('weight_bit', torch.zeros(1))
+                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
+                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
             if "output" in config.get("quant_types", []):
                 layer.module.register_buffer('activation_bit', torch.zeros(1))
-                layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
-                layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
-                layer.module.register_buffer('tracked_min', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
-                layer.module.register_buffer('tracked_max', torch.zeros(1))
+                layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
+                layer.module.register_buffer('tracked_max_activation', torch.zeros(1))
 
     def _del_simulated_attr(self, module):
         """
         delete redundant parameters in quantize module
         """
-        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_biased', 'tracked_max_biased', 'tracked_min', \
-            'tracked_max', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
+        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
+            'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)
...
...
@@ -243,15 +244,26 @@ class QAT_Quantizer(Quantizer):
     def quantize_weight(self, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
+        input = kwargs['input_tensor']
         weight = copy.deepcopy(wrapper.module.old_weight.data)
         weight_bits = get_bits_length(config, 'weight')
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
 
         # we dont update weight in evaluation stage
-        if quant_start_step > self.bound_model.steps or not wrapper.training:
+        if quant_start_step > self.bound_model.steps:
+            module.tracked_min_input, module.tracked_max_input = torch.min(input), torch.max(input)
+            return weight
+
+        if not wrapper.training:
             return weight
 
+        current_min, current_max = torch.min(input), torch.max(input)
+        module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
+                                              module.ema_decay)
+        module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
+                                              module.ema_decay)
+
         # if bias exists, quantize bias to uint32
         if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
             bias = wrapper.module.bias.data
...
...
@@ -281,17 +293,17 @@ class QAT_Quantizer(Quantizer):
         assert output_bits >= 1, "quant bits length should be at least 1"
 
         if quant_start_step > self.bound_model.steps:
-            module.tracked_min_biased, module.tracked_max_biased = torch.min(output), torch.max(output)
+            module.tracked_min_activation, module.tracked_max_activation = torch.min(output), torch.max(output)
             return output
 
         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_biased = update_ema(module.tracked_min_biased, current_min,
+            module.tracked_min_activation = update_ema(module.tracked_min_activation, current_min,
                                                    module.ema_decay)
-            module.tracked_max_biased = update_ema(module.tracked_max_biased, current_max,
+            module.tracked_max_activation = update_ema(module.tracked_max_activation, current_max,
                                                    module.ema_decay)
-        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_biased, module.tracked_max_biased)
+        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_activation, module.tracked_max_activation)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out
...
...
@@ -327,10 +339,12 @@ class QAT_Quantizer(Quantizer):
             calibration_config[name] = {}
             if hasattr(module, 'weight_bit'):
                 calibration_config[name]['weight_bit'] = int(module.weight_bit)
+                calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
+                calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
             if hasattr(module, 'activation_bit'):
                 calibration_config[name]['activation_bit'] = int(module.activation_bit)
-                calibration_config[name]['tracked_min'] = float(module.tracked_min_biased)
-                calibration_config[name]['tracked_max'] = float(module.tracked_max_biased)
+                calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
+                calibration_config[name]['tracked_max_activation'] = float(module.tracked_max_activation)
             self._del_simulated_attr(module)
 
         self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
...
...
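The renamed buffers above are refreshed with update_ema, which is defined elsewhere in quantizers.py and not shown in this diff. For reference, a rough sketch of what it presumably computes (an assumption, not taken from this commit) is a standard exponential moving average of the observed range:

# Rough sketch of what update_ema presumably does (the helper itself is outside this diff).
def update_ema_sketch(tracked_value, current_value, decay):
    return tracked_value * decay + (1 - decay) * current_value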
nni/compression/pytorch/compressor.py
...
...
@@ -483,7 +483,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
             self.quantizer.quant_grad.apply(
                 self.module.old_weight,
                 QuantType.QUANT_WEIGHT,
-                self)
+                self, inputs[0])
             result = self.module(*inputs)
         else:
             result = self.module(*inputs)
...
...
@@ -511,14 +511,12 @@ class Quantizer(Compressor):
                 # and it is trainable, therefore, it should be added to optimizer.
                 self.optimizer.add_param_group({"params": wrapper.module.old_weight})
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
         """
         quantize should overload this method to quantize weight.
         This method is effectively hooked to :meth:`forward` of the model.
 
         Parameters
         ----------
-        weight : Tensor
-            weight that needs to be quantized
         wrapper : QuantizerModuleWrapper
             the wrapper for origin module
         """
...
...
@@ -720,11 +718,11 @@ class QuantGrad(torch.autograd.Function):
         return grad_output
 
     @staticmethod
-    def forward(ctx, tensor, quant_type, wrapper, **kwargs):
+    def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
         if quant_type == QuantType.QUANT_INPUT:
             output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_WEIGHT:
-            output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
         elif quant_type == QuantType.QUANT_OUTPUT:
             output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
         else:
...
...
nni/compression/pytorch/quantization_speedup/__init__.py  (new file, 0 → 100644)
from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT
\ No newline at end of file
nni/compression/pytorch/quantization_speedup/backend.py  (new file, 0 → 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


class BaseModelSpeedup:
    """
    Base speedup class for backend engine
    """
    def __init__(self, model, config):
        """
        Parameters
        ----------
        model : pytorch model
            The model to speed up by quantization.
        config : dict
            Config recording bit number and name of layers.
        """
        self.model = model
        self.config = config

    def inference(self, test_data):
        """
        This function should be overridden by the subclass to provide inference ability,
        which should return both the output and the inference time.

        Parameters
        ----------
        test_data : numpy data
            test data given to the inference engine

        Returns
        -------
        numpy data
            output data generated by the inference
        float
            latency of the inference process
        """
        raise NotImplementedError('Backend engine must overload inference()')

    def compress(self):
        """
        This function should be overridden by the subclass to build the inference
        engine which will be used to process input data.
        """
        raise NotImplementedError('Backend engine must overload compress()')

    def export_quantized_model(self, path):
        """
        This function should be overridden by the subclass to export the quantized
        model engine to the given path.
        """
        raise NotImplementedError('Backend engine must overload export_quantized_model()')
\ No newline at end of file
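BaseModelSpeedup fixes the contract every backend must satisfy: compress() builds the engine, inference() returns the output together with the measured latency, and export_quantized_model() persists the engine. A minimal hypothetical subclass, only to illustrate the expected shapes of the return values:

import time

class IdentitySpeedup(BaseModelSpeedup):
    """Hypothetical backend that simply runs the original model (illustration only)."""
    def compress(self):
        self.engine = self.model              # a real backend would build an optimized engine here

    def inference(self, test_data):
        start = time.time()
        output = self.engine(test_data)       # run the "engine" on one batch
        return output, time.time() - start    # output plus measured latency

    def export_quantized_model(self, path):
        pass                                  # a real backend would serialize its engine here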
nni/compression/pytorch/quantization_speedup/calibrator.py  (new file, 0 → 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import logging
import tensorrt as trt
import pycuda.driver as cuda

logger = logging.getLogger(__name__)


class Calibrator(trt.IInt8Calibrator):
    def __init__(self, training_data, cache_file, batch_size=64, algorithm=trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2):
        """
        Parameters
        ----------
        training_data : numpy array
            The data used to calibrate the quantized model
        cache_file : str
            The path where the calibration cache file is stored
        batch_size : int
            The batch size of the calibration process
        algorithm : tensorrt.tensorrt.CalibrationAlgoType
            The calibration algorithm: one of LEGACY_CALIBRATION,
            ENTROPY_CALIBRATION, ENTROPY_CALIBRATION_2, MINMAX_CALIBRATION.
            Please refer to https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Int8/Calibrator.html for details
        """
        trt.IInt8Calibrator.__init__(self)

        self.algorithm = algorithm
        self.cache_file = cache_file
        self.data = training_data
        self.batch_size = batch_size
        self.current_index = 0

        # Allocate enough memory for a whole batch.
        self.device_input = cuda.mem_alloc(self.data[0].nbytes * self.batch_size)

    def get_algorithm(self):
        return self.algorithm

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        """
        This function defines how one batch of calibration data is fed to the engine.

        Parameters
        ----------
        names : str
            The names of the network inputs for each object in the bindings array

        Returns
        -------
        list
            A list of device memory pointers set to the memory containing each network
            input data, or an empty list if there are no more batches for calibration.
            You can allocate these device buffers with pycuda, for example, and then
            cast them to int to retrieve the pointer
        """
        if self.current_index + self.batch_size > self.data.shape[0]:
            return None

        current_batch = int(self.current_index / self.batch_size)
        if current_batch % 10 == 0:
            logger.info("Calibrating batch %d, containing %d images", current_batch, self.batch_size)

        batch = self.data[self.current_index:self.current_index + self.batch_size].ravel()
        cuda.memcpy_htod(self.device_input, batch)
        self.current_index += self.batch_size

        memory_pointers = [self.device_input]
        return memory_pointers

    def read_calibration_cache(self):
        """
        If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.

        Returns
        -------
        cache object
            A cache object which contains calibration parameters for quantization
        """
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        """
        Write the calibration cache to the specified path.

        Parameters
        ----------
        cache : str
            The calibration cache to write
        """
        with open(self.cache_file, "wb") as f:
            f.write(cache)
\ No newline at end of file
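Calibrator adapts an in-memory numpy array to TensorRT's IInt8Calibrator interface; build_engine in integrated_tensorrt.py attaches it to the builder so TensorRT can sweep the data and derive int8 dynamic ranges. A usage sketch with made-up calibration data:

import numpy as np

calib_data = np.random.randn(640, 1, 28, 28).astype(np.float32)   # hypothetical calibration samples
calib = Calibrator(calib_data, "calibration.cache", batch_size=64,
                   algorithm=trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2)
# build_engine then registers it on the builder: builder.int8_calibrator = calib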
nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py  (new file, 0 → 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import onnx
import onnx.numpy_helper

"""
The main purpose of this file is to convert a pytorch model to an onnx model.
The conversion itself is straightforward, but the layer names of the pytorch model
cannot be mapped to onnx layer names directly. To solve this, we wrap the pytorch
model in a wrapper which multiplies the input of each op by the index of its config
entry before computation. Only in this way can the onnx model recover the bit
setting of the corresponding layer.
"""

class LayernameModuleWrapper(torch.nn.Module):
    def __init__(self, module, module_bit) -> None:
        """
        Parameters
        ----------
        module : torch.nn.Module
            Layer module of pytorch model
        module_bit : int
            Bit width setting for module
        """
        super().__init__()
        self.module = module
        self.module_bit = module_bit

    def forward(self, inputs):
        inputs = inputs * self.module_bit
        inputs = self.module(inputs)
        return inputs

def _setattr(model, name, module):
    """
    Parameters
    ----------
    model : pytorch model
        The model to speed up by quantization
    name : str
        name of pytorch module
    module : torch.nn.Module
        Layer module of pytorch model
    """
    name_list = name.split(".")
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)

def unwrapper(model_onnx, index2name, config):
    """
    Fill the onnx config and remove the wrapper nodes in the onnx model.

    Parameters
    ----------
    model_onnx : onnx model
        Onnx model which is converted from pytorch model
    index2name : dict
        Dictionary of layer index and name
    config : dict
        Config recording name of layers and calibration parameters

    Returns
    -------
    onnx model
        Onnx model which is converted from pytorch model
    dict
        The configuration of onnx model layers and calibration parameters
    """
    # Support Gemm, Conv, Relu, Clip(Relu6) and Maxpool
    support_op = ['Gemm', 'Conv', 'Relu', 'Clip', 'MaxP']
    idx = 0
    onnx_config = {}
    while idx < len(model_onnx.graph.node):
        nd = model_onnx.graph.node[idx]
        if nd.name[0:4] in support_op and idx > 1:
            # Grab the constant node and the multiply node
            const_nd = model_onnx.graph.node[idx - 2]
            mul_nd = model_onnx.graph.node[idx - 1]
            # Get the index number which is transferred by the constant node
            index = int(onnx.numpy_helper.to_array(const_nd.attribute[0].t))
            if index != -1:
                name = index2name[index]
                onnx_config[nd.name] = config[name]
            nd.input[0] = mul_nd.input[0]
            # Remove the constant node and the multiply node
            model_onnx.graph.node.remove(const_nd)
            model_onnx.graph.node.remove(mul_nd)
            idx = idx - 2
        idx = idx + 1
    return model_onnx, onnx_config

def torch_to_onnx(model, config, input_shape, model_path, input_names, output_names):
    """
    Convert a torch model to an onnx model and get the layer bit config of the onnx model.

    Parameters
    ----------
    model : pytorch model
        The model to speed up by quantization
    config : dict
        Config recording bit number and name of layers
    input_shape : tuple
        The input shape of the model, passed to torch.onnx.export
    model_path : str
        The path where the onnx model converted from the pytorch model is stored
    input_names : list
        Input names of the onnx model, passed to torch.onnx.export
    output_names : list
        Output names of the onnx model, passed to torch.onnx.export

    Returns
    -------
    onnx model
        Onnx model which is converted from pytorch model
    dict
        The configuration of onnx model layers and calibration parameters
    """
    # Support Gemm, Conv, Relu, Clip(Relu6) and MaxPool
    support_op = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.ReLU, torch.nn.ReLU6, torch.nn.MaxPool2d]

    # Transfer bit number to onnx layer by using wrapper
    index2name = {}
    name2index = {}
    if config is not None:
        for i, name in enumerate(config.keys()):
            index2name[i] = name
            name2index[name] = i
    for name, module in model.named_modules():
        if config is not None and name in config:
            assert type(module) in support_op
            wrapper_module = LayernameModuleWrapper(module, name2index[name])
            _setattr(model, name, wrapper_module)
        elif type(module) in support_op:
            wrapper_module = LayernameModuleWrapper(module, -1)
            _setattr(model, name, wrapper_module)

    # Convert torch model to onnx model and save it in model_path
    dummy_input = torch.randn(input_shape)
    model.to('cpu')
    torch.onnx.export(model, dummy_input, model_path, verbose=False,
                      input_names=input_names, output_names=output_names, export_params=True)

    # Load onnx model
    model_onnx = onnx.load(model_path)
    model_onnx, onnx_config = unwrapper(model_onnx, index2name, config)
    onnx.save(model_onnx, model_path)

    onnx.checker.check_model(model_onnx)
    return model_onnx, onnx_config
\ No newline at end of file
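In short, torch_to_onnx wraps every supported layer so that its config index survives torch.onnx.export as a Constant/Mul node pair, and unwrapper then strips those helper nodes while remapping the config onto ONNX node names. A sketch of the call, reusing the Mnist model from the example file and an illustrative config:

# Sketch: carry a pytorch-layer bit config into ONNX node names (names/paths are illustrative).
config = {'conv1': {'weight_bit': 8, 'activation_bit': 8}}
model_onnx, onnx_config = torch_to_onnx(Mnist(), config, input_shape=(32, 1, 28, 28),
                                        model_path='mnist.onnx',
                                        input_names=['actual_input_1'],
                                        output_names=['output1'])
# onnx_config is keyed by ONNX node names (e.g. 'Conv_...') and is what build_engine consumes.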
nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py  (new file, 0 → 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time
import logging
import tensorrt as trt
import numpy as np
import torch

from . import frontend_to_onnx as fonnx
from . import calibrator as calibrator
from . import trt_pycuda as common
from .backend import BaseModelSpeedup

# TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
TRT_LOGGER = trt.Logger()
logger = logging.getLogger(__name__)

class CalibrateType:
    LEGACY = trt.CalibrationAlgoType.LEGACY_CALIBRATION
    ENTROPY = trt.CalibrationAlgoType.ENTROPY_CALIBRATION
    ENTROPY2 = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2
    MINMAX = trt.CalibrationAlgoType.MINMAX_CALIBRATION

Precision_Dict = {
    8: trt.int8,
    16: trt.float16,
    32: trt.float32
}

def valid_config(config=None):
    """
    This function validates the bit setting configuration
    """
    if config is None:
        return
    support_bit = [8, 16, 32]
    for name in config.keys():
        if 'weight_bit' in config[name]:
            w_bit = config[name]['weight_bit']
            assert w_bit in support_bit, "weight bit should be 8, 16, 32"
        if 'activation_bit' in config[name]:
            a_bit = config[name]['activation_bit']
            assert a_bit in support_bit, "activation bit should be 8, 16, 32"

def handle_gemm(network, layer_idx, config):
    """
    This function handles the special case of gemm operations, whose layer count
    changes during the pytorch->onnx model conversion.

    Parameters
    ----------
    network : tensorrt.INetworkDefinition
        Represents a TensorRT Network from which the Builder can build an Engine
    layer_idx : int
        layer index of gemm
    config : dict
        Config recording bit number and name of layers
    """
    layer = network.get_layer(layer_idx)
    pre_layer = network.get_layer(layer_idx - 1)
    next_layer = network.get_layer(layer_idx + 1)
    # if weight bit exists, set three layers' precision,
    # input tensor range and the first two layers' output type
    if 'weight_bit' in config[layer.name]:
        assert 'tracked_min_input' in config[layer.name]
        assert 'tracked_max_input' in config[layer.name]
        w_bit = config[layer.name]['weight_bit']
        tracked_min_input = config[layer.name]['tracked_min_input']
        tracked_max_input = config[layer.name]['tracked_max_input']
        # set three layers the same precision
        layer.precision = Precision_Dict[w_bit]
        pre_layer.precision = Precision_Dict[w_bit]
        next_layer.precision = Precision_Dict[w_bit]
        # set the first two layers' output type
        pre_layer.set_output_type(0, Precision_Dict[w_bit])
        layer.set_output_type(0, Precision_Dict[w_bit])
        pre_in_tensor = pre_layer.get_input(0)
        in_tensor = layer.get_input(0)
        next_in_tensor = next_layer.get_input(0)
        # set three layers' input tensor range
        pre_in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
        in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
        next_in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)

    # if activation bit exists, set the last layer's output type and output tensor range
    if 'activation_bit' in config[layer.name]:
        assert 'tracked_min_activation' in config[layer.name]
        assert 'tracked_max_activation' in config[layer.name]
        a_bit = config[layer.name]['activation_bit']
        tracked_min_activation = config[layer.name]['tracked_min_activation']
        tracked_max_activation = config[layer.name]['tracked_max_activation']
        # set the last layer's output type
        next_layer.set_output_type(0, Precision_Dict[a_bit])
        next_out_tensor = next_layer.get_output(0)
        # set the last layer's output tensor range
        next_out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)

def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
    """
    This function builds an engine from an onnx model, with an optional calibration process.

    Parameters
    ----------
    model_file : str
        The path of onnx model
    config : dict
        Config recording bit number and name of layers
    extra_layer_bit : int
        Layers which are not in config will be quantized to this bit number
    strict_datatype : bool
        Whether to constrain layer bits to the numbers given in config. If true, all the layers
        will be set to the given bits strictly. Otherwise, these layers will be set automatically
        by tensorrt.
    calib : numpy array
        The data used to calibrate the quantized model

    Returns
    -------
    tensorrt.ICudaEngine
        An ICudaEngine for executing inference on a built network
    """
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, \
        trt.OnnxParser(network, TRT_LOGGER) as parser:
        # Attention: max_batch_size should be set to 1 because of the implementation of allocate_buffers
        builder.max_batch_size = 1
        builder.max_workspace_size = common.GiB(4)

        if extra_layer_bit == 32 and config is None:
            pass
        elif extra_layer_bit == 16 and config is None:
            builder.fp16_mode = True
        elif extra_layer_bit == 8 and config is None:
            # entire model in 8bit mode
            builder.int8_mode = True
        else:
            builder.int8_mode = True
            builder.fp16_mode = True
            builder.strict_type_constraints = strict_datatype

        valid_config(config)

        # Parse onnx model
        with open(model_file, 'rb') as model:
            if not parser.parse(model.read()):
                logger.error('ERROR: Fail to parse the ONNX file.')
                for error in range(parser.num_errors):
                    logger.error(parser.get_error(error))
                return None

        if calib is not None:
            builder.int8_calibrator = calib
            # This design may not be correct if there is more than one output
            for i in range(network.num_layers):
                if config is None:
                    break
                layer = network.get_layer(i)
                if layer.name in config:
                    w_bit = config[layer.name]['weight_bit']
                    a_bit = config[layer.name]['activation_bit']
                    layer.precision = Precision_Dict[w_bit]
                    layer.set_output_type(0, Precision_Dict[a_bit])
        else:
            # This implementation may be incorrect when the output number > 1
            for i in range(network.num_layers):
                if config is None:
                    # no low bit layer needs to be set, keep original model
                    break
                layer = network.get_layer(i)
                if layer.name not in config:
                    continue

                # the layer count of gemm changes during pytorch->onnx model conversion, so it needs special handling
                if layer.name[0:4] == "Gemm":
                    handle_gemm(network, i, config)
                    continue

                # If weight_bit exists in config, set layer precision and the layer's input tensor dynamic range.
                if 'weight_bit' in config[layer.name]:
                    assert 'tracked_min_input' in config[layer.name]
                    assert 'tracked_max_input' in config[layer.name]
                    w_bit = config[layer.name]['weight_bit']
                    tracked_min_input = config[layer.name]['tracked_min_input']
                    tracked_max_input = config[layer.name]['tracked_max_input']
                    layer.precision = Precision_Dict[w_bit]
                    in_tensor = layer.get_input(0)
                    in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)

                # If activation_bit exists in config, set layer output type and the layer's output tensor dynamic range.
                if 'activation_bit' in config[layer.name]:
                    assert 'tracked_min_activation' in config[layer.name]
                    assert 'tracked_max_activation' in config[layer.name]
                    a_bit = config[layer.name]['activation_bit']
                    tracked_min_activation = config[layer.name]['tracked_min_activation']
                    tracked_max_activation = config[layer.name]['tracked_max_activation']
                    layer.set_output_type(0, Precision_Dict[a_bit])
                    out_tensor = layer.get_output(0)
                    out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)

        # Build engine and do int8 calibration.
        engine = builder.build_cuda_engine(network)
        return engine

class ModelSpeedupTensorRT(BaseModelSpeedup):
    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx",
                 extra_layer_bit=32, strict_datatype=True, calibrate_type=CalibrateType.ENTROPY2,
                 calib_data_loader=None, calibration_cache="calibration.cache", batchsize=1,
                 input_names=["actual_input_1"], output_names=["output1"]):
        """
        Parameters
        ----------
        model : pytorch model
            The model to speed up by quantization.
        input_shape : tuple
            The input shape of the model, passed to torch.onnx.export.
        config : dict
            Config recording bit number and name of layers.
        onnx_path : str
            The path where the onnx model converted from the pytorch model is stored.
        extra_layer_bit : int
            Layers which are not in config will be quantized to this bit number.
        strict_datatype : bool
            Whether to constrain layer bits to the numbers given in config. If true, all the layers
            will be set to the given bits strictly. Otherwise, these layers will be set automatically
            by tensorrt.
        calibrate_type : tensorrt.tensorrt.CalibrationAlgoType
            The calibration algorithm. Please refer to https://docs.nvidia.com/deeplearning/
            tensorrt/api/python_api/infer/Int8/Calibrator.html for details.
        calib_data_loader : DataLoader or torch.Tensor
            The data used to calibrate the quantized model
        calibration_cache : str
            The path where the calibration cache file is stored
        batchsize : int
            The batch size of calibration and inference
        input_names : list
            Input names of the onnx model, passed to torch.onnx.export
        output_names : list
            Output names of the onnx model, passed to torch.onnx.export
        """
        super().__init__(model, config)
        self.model = model
        self.onnx_path = onnx_path
        self.input_shape = input_shape
        self.config = config
        self.extra_layer_bit = extra_layer_bit
        self.strict_datatype = strict_datatype
        self.calibrate_type = calibrate_type
        self.calib_data_loader = calib_data_loader
        self.calibration_cache = calibration_cache
        self.batchsize = batchsize
        self.input_names = input_names
        self.output_names = output_names
        self.context = None
        self.onnx_config = {}

    def compress(self):
        """
        Get onnx config and build tensorrt engine.
        """
        assert self.model is not None
        assert self.onnx_path is not None
        assert self.input_shape is not None

        # Convert pytorch model to onnx model and save onnx model in onnx_path
        _, self.onnx_config = fonnx.torch_to_onnx(self.model, self.config, input_shape=self.input_shape,
            model_path=self.onnx_path, input_names=self.input_names, output_names=self.output_names)

        if self.calib_data_loader is not None:
            assert self.calibrate_type is not None
            context = self._tensorrt_build_withcalib(self.onnx_path)
        else:
            context = self._tensorrt_build_withoutcalib(self.onnx_path)
        self.context = context

    def _tensorrt_build_withcalib(self, onnx_path):
        """
        Build the inference engine with calibration.

        Parameters
        ----------
        onnx_path : str
            The path of onnx model

        Returns
        -------
        tensorrt.IExecutionContext
            Context for executing inference using an ICudaEngine
        """
        calib_data = None
        if type(self.calib_data_loader) == torch.utils.data.dataloader.DataLoader:
            calib_data_set = []
            for data, _ in self.calib_data_loader:
                calib_data_set.append(data)
            calib_data = np.concatenate(calib_data_set)
        elif type(self.calib_data_loader) == torch.Tensor:
            # convert pytorch tensor to numpy ndarray
            calib_data = self.calib_data_loader.numpy()
        else:
            raise ValueError("Unsupported calibration data type")

        calib = calibrator.Calibrator(calib_data, self.calibration_cache, self.batchsize, self.calibrate_type)

        # build inference engine with calibration
        engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bit, self.strict_datatype, calib)
        return engine.create_execution_context()

    def _tensorrt_build_withoutcalib(self, onnx_path):
        """
        Build the inference engine without calibration.

        Parameters
        ----------
        onnx_path : str
            The path of onnx model

        Returns
        -------
        tensorrt.IExecutionContext
            Context for executing inference using an ICudaEngine
        """
        engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bit, self.strict_datatype)
        return engine.create_execution_context()

    def inference(self, test_data):
        """
        Do inference with the built tensorrt engine.

        Parameters
        ----------
        test_data : pytorch tensor
            Model input tensor
        """
        # convert pytorch tensor to numpy ndarray
        test_data = test_data.numpy()
        # Numpy dtype should be float32
        assert test_data.dtype == np.float32
        elapsed_time = 0
        inputs, outputs, bindings, stream = common.allocate_buffers(self.context.engine)
        result = []
        for start_idx in range(0, test_data.shape[0], self.batchsize):
            # If the number of images in the test set is not divisible by the batch size, the last batch will be smaller.
            # This logic is used for handling that case.
            end_idx = min(start_idx + self.batchsize, test_data.shape[0])
            effective_batch_size = end_idx - start_idx

            # Do inference for every batch.
            inputs[0].host = test_data[start_idx:start_idx + effective_batch_size]
            t1 = time.time()
            [output] = common.do_inference_v2(self.context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
            elapsed_time += time.time() - t1
            shape = output.shape[0]
            output = output[0:int(shape * effective_batch_size / self.batchsize)].reshape(effective_batch_size, -1)
            result.append(output.copy())
        # Use argmax to get predictions and then check accuracy
        # convert numpy ndarray to pytorch tensor
        result = torch.Tensor(np.concatenate(result))
        return result, elapsed_time

    def export_quantized_model(self, path):
        """
        Export the TensorRT quantized model engine, which can only be loaded by the TensorRT deserialize API.

        Parameters
        ----------
        path : str
            The path of the exported model
        """
        assert path is not None
        with open(path, "wb") as f:
            f.write(self.context.engine.serialize())
            logger.info("TensorRT engine has been saved to %s", path)

    def load_quantized_model(self, path):
        """
        Load a TensorRT quantized model engine from the specified path.

        Parameters
        ----------
        path : str
            The path of the exported model
        """
        assert path is not None
        with open(path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            engine = runtime.deserialize_cuda_engine(f.read())
            self.context = engine.create_execution_context()
            logger.info("Load TensorRT engine from %s successfully.", path)
\ No newline at end of file
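Besides the calibration-config path shown in the example file, ModelSpeedupTensorRT also accepts a plain torch.Tensor as calibration data and can quantize the whole model to one bit width via extra_layer_bit. A rough post-training sketch under those assumptions (shapes and the calibration tensor are illustrative):

# Whole-model int8 post-training sketch, assuming a trained `model`.
calib_tensor = torch.randn(320, 1, 28, 28)
engine = ModelSpeedupTensorRT(model, (32, 1, 28, 28), config=None, extra_layer_bit=8,
                              calib_data_loader=calib_tensor, batchsize=32)
engine.compress()
output, latency = engine.inference(torch.randn(32, 1, 28, 28))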
nni/compression/pytorch/quantization_speedup/trt_pycuda.py  (new file, 0 → 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pycuda.driver as cuda
import pycuda.autoinit # pylint: disable=unused-import
import tensorrt as trt

EXPLICIT_BATCH = 1

def GiB(val):
    return val * 1 << 30

# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        """
        A pair of host and device memory buffers for one engine binding.

        Parameters
        ----------
        host_mem : host memory
            Memory buffers of host
        device_mem : device memory
            Memory buffers of device
        """
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

def allocate_buffers(engine):
    """
    Allocates all buffers required for an engine, i.e. host/device inputs/outputs.

    Parameters
    ----------
    engine : tensorrt.ICudaEngine
        An ICudaEngine for executing inference on a built network

    Returns
    -------
    list
        All input HostDeviceMem of an engine
    list
        All output HostDeviceMem of an engine
    GPU bindings
        Device bindings
    GPU stream
        A stream is a sequence of commands (possibly issued by different host threads) that execute in order
    """
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream

# This function is generalized for multiple inputs/outputs for full dimension networks.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference_v2(context, bindings, inputs, outputs, stream):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]
\ No newline at end of file
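These helpers follow the standard TensorRT/pycuda sample pattern: allocate_buffers pins one host/device buffer pair per binding, and do_inference_v2 copies the inputs in, runs the execution context asynchronously, and copies the outputs back. ModelSpeedupTensorRT.inference uses them roughly like this (simplified sketch, assuming an existing IExecutionContext named context and a float32 numpy batch):

# Simplified sketch of one batched inference call through these helpers.
inputs, outputs, bindings, stream = allocate_buffers(context.engine)
inputs[0].host = batch_np                        # one batch as a float32 numpy array
[result] = do_inference_v2(context, bindings=bindings, inputs=inputs,
                           outputs=outputs, stream=stream)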
pylintrc
...
...
@@ -45,6 +45,6 @@ enable= unused-wildcard-import,
 ignore-patterns=test*
 
 # List of members which are set dynamically and missed by pylint inference
-generated-members=numpy.*,torch.*,tensorflow.*
+generated-members=numpy.*,torch.*,tensorflow.*,pycuda.*,tensorrt.*
 
-ignored-modules=tensorflow,_winapi,msvcrt
+ignored-modules=tensorflow,_winapi,msvcrt,tensorrt,pycuda
test/ut/sdk/test_compressor_torch.py
...
...
@@ -239,15 +239,16 @@ class CompressorTestCase(TestCase):
         # test quantize
         # range not including 0
         eps = 1e-7
+        input = torch.tensor([[0, 4], [2, 1]]).float()
         weight = torch.tensor([[1, 2], [3, 5]]).float()
         model.conv2.module.old_weight.data = weight
-        quantizer.quantize_weight(model.conv2)
+        quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point == 0
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
         model.conv2.module.old_weight.data = weight
-        quantizer.quantize_weight(model.conv2)
+        quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point in (42, 43)
         # test value of weight and bias after quantization
...
...
@@ -257,7 +258,7 @@ class CompressorTestCase(TestCase):
         bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
         model.conv2.module.old_weight.data = weight
         model.conv2.module.bias.data = bias
-        quantizer.quantize_weight(model.conv2)
+        quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
         assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))
...
...
@@ -265,14 +266,14 @@ class CompressorTestCase(TestCase):
         eps = 1e-7
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
         out = model.relu(x)
-        assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
-        assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_min_activation, 0, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_max_activation, 0.002, abs_tol=eps)
 
         quantizer.step_with_optimizer()
         x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
         out = model.relu(x)
-        assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
-        assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_min_activation, 0.002, abs_tol=eps)
+        assert math.isclose(model.relu.module.tracked_max_activation, 0.00998, abs_tol=eps)
 
     def test_torch_quantizer_export(self):
         config_list_qat = [{
...
...
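For reference, the expected values in the first hunk of this test follow from the usual asymmetric affine quantization scheme (an assumption here, since the scheme itself is outside this diff): the observed weight range is first extended to include zero, then scale = (max − min) / 255 for 8 bits and zero_point = round(−min / scale). A quick check:

# weight range [1, 5] -> extended to [0, 5]
scale = (5 - 0) / 255              # 5 / 255, zero_point = round(-0 / scale) = 0
# weight range [-1, 5]
scale = (5 - (-1)) / 255           # 6 / 255
zero_point = round(1 / scale)      # round(42.5) -> 42 or 43 depending on the rounding mode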