Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a4760ce8
"...composable_kernel_onnxruntime.git" did not exist on "e00149ac677b490ee7011d3894a37233ccacae93"
Unverified
Commit
a4760ce8
authored
Jul 09, 2021
by
lin bin
Committed by
GitHub
Jul 09, 2021
Browse files
[Quantization speedup]support TensorRT8.0.0 (#3866)
parent
4b1f46a3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
10 deletions
+42
-10
docs/en_US/Compression/QuantizationSpeedup.rst
docs/en_US/Compression/QuantizationSpeedup.rst
+4
-0
nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py
...ssion/pytorch/quantization_speedup/integrated_tensorrt.py
+38
-10
No files found.
docs/en_US/Compression/QuantizationSpeedup.rst
View file @
a4760ce8
...
@@ -50,6 +50,10 @@ CUDA version >= 11.0
...
@@ -50,6 +50,10 @@ CUDA version >= 11.0
TensorRT
version
>=
7.2
TensorRT
version
>=
7.2
Note
*
If
you
haven
't installed TensorRT before or use the old version, please refer to `TensorRT Installation Guide <https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html>`__\
Usage
Usage
-----
-----
quantization aware training:
quantization aware training:
...
...
nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py
View file @
a4760ce8
...
@@ -12,7 +12,8 @@ from . import calibrator as calibrator
...
@@ -12,7 +12,8 @@ from . import calibrator as calibrator
from
.
import
trt_pycuda
as
common
from
.
import
trt_pycuda
as
common
from
.backend
import
BaseModelSpeedup
from
.backend
import
BaseModelSpeedup
# TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
TRT8
=
8
TRT7
=
7
TRT_LOGGER
=
trt
.
Logger
()
TRT_LOGGER
=
trt
.
Logger
()
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -120,18 +121,39 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
...
@@ -120,18 +121,39 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
An ICudaEngine for executing inference on a built network
An ICudaEngine for executing inference on a built network
"""
"""
with
trt
.
Builder
(
TRT_LOGGER
)
as
builder
,
builder
.
create_network
(
common
.
EXPLICIT_BATCH
)
as
network
,
\
with
trt
.
Builder
(
TRT_LOGGER
)
as
builder
,
builder
.
create_network
(
common
.
EXPLICIT_BATCH
)
as
network
,
\
trt
.
OnnxParser
(
network
,
TRT_LOGGER
)
as
parser
:
trt
.
OnnxParser
(
network
,
TRT_LOGGER
)
as
parser
,
builder
.
create_builder_config
()
as
trt_config
:
# Attention that, builder should be set to 1 because of the implementation of allocate_buffer
# Attention that, builder should be set to 1 because of the implementation of allocate_buffer
trt_version
=
int
(
trt
.
__version__
[
0
])
assert
trt_version
==
TRT8
or
trt_version
==
TRT7
,
"Version of TensorRT is too old, please
\
update TensorRT to version >= 7.0"
if
trt_version
==
TRT7
:
logger
.
warning
(
"TensorRT7 is deprecated and may be removed in the following release."
)
builder
.
max_batch_size
=
1
builder
.
max_batch_size
=
1
if
trt_version
==
TRT8
:
trt_config
.
max_workspace_size
=
common
.
GiB
(
4
)
else
:
builder
.
max_workspace_size
=
common
.
GiB
(
4
)
builder
.
max_workspace_size
=
common
.
GiB
(
4
)
if
extra_layer_bit
==
32
and
config
is
None
:
if
extra_layer_bit
==
32
and
config
is
None
:
pass
pass
elif
extra_layer_bit
==
16
and
config
is
None
:
elif
extra_layer_bit
==
16
and
config
is
None
:
if
trt_version
==
TRT8
:
trt_config
.
set_flag
(
trt
.
BuilderFlag
.
FP16
)
else
:
builder
.
fp16_mode
=
True
builder
.
fp16_mode
=
True
elif
extra_layer_bit
==
8
and
config
is
None
:
elif
extra_layer_bit
==
8
and
config
is
None
:
# entire model in 8bit mode
# entire model in 8bit mode
if
trt_version
==
TRT8
:
trt_config
.
set_flag
(
trt
.
BuilderFlag
.
INT8
)
else
:
builder
.
int8_mode
=
True
builder
.
int8_mode
=
True
else
:
if
trt_version
==
TRT8
:
trt_config
.
set_flag
(
trt
.
BuilderFlag
.
INT8
)
trt_config
.
set_flag
(
trt
.
BuilderFlag
.
FP16
)
if
strict_datatype
:
trt_config
.
set_flag
(
trt
.
BuilderFlag
.
STRICT_TYPES
)
else
:
else
:
builder
.
int8_mode
=
True
builder
.
int8_mode
=
True
builder
.
fp16_mode
=
True
builder
.
fp16_mode
=
True
...
@@ -148,6 +170,9 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
...
@@ -148,6 +170,9 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
return
None
return
None
if
calib
is
not
None
:
if
calib
is
not
None
:
if
trt_version
==
TRT8
:
trt_config
.
int8_calibrator
=
calib
else
:
builder
.
int8_calibrator
=
calib
builder
.
int8_calibrator
=
calib
# This design may not be correct if output more than one
# This design may not be correct if output more than one
for
i
in
range
(
network
.
num_layers
):
for
i
in
range
(
network
.
num_layers
):
...
@@ -196,7 +221,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
...
@@ -196,7 +221,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
out_tensor
.
dynamic_range
=
(
tracked_min_activation
,
tracked_max_activation
)
out_tensor
.
dynamic_range
=
(
tracked_min_activation
,
tracked_max_activation
)
# Build engine and do int8 calibration.
# Build engine and do int8 calibration.
engine
=
builder
.
build_cuda_engine
(
network
)
if
trt_version
==
TRT8
:
engine
=
builder
.
build_engine
(
network
,
trt_config
)
else
:
engine
.
builder
.
build_cuda_engine
(
network
)
return
engine
return
engine
class
ModelSpeedupTensorRT
(
BaseModelSpeedup
):
class
ModelSpeedupTensorRT
(
BaseModelSpeedup
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment