"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "297a1e2e4d3bb7c41624899b8eeb4df83aa2ca18"
Unverified commit a4760ce8, authored by lin bin, committed by GitHub

[Quantization speedup] Support TensorRT 8.0.0 (#3866)

parent 4b1f46a3
@@ -50,6 +50,10 @@ CUDA version >= 11.0
 TensorRT version >= 7.2
+Note
+* If you haven't installed TensorRT before or use the old version, please refer to `TensorRT Installation Guide <https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html>`__\
 Usage
 -----
 quantization aware training:
...
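For context on the Usage section shown above, here is a rough sketch of how the quantization speedup API is driven after quantization aware training, assuming the `ModelSpeedupTensorRT` class exported by this package; `model`, `calibration_config`, `input_shape`, and `test_data` are placeholders, and the exact constructor arguments and return values may differ between NNI versions.

import torch
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

# Placeholders: a real workflow would use a quantization-aware-trained model and
# the calibration_config exported by the quantizer (e.g. via its export_model call).
model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU()).eval()
calibration_config = {}
input_shape = (32, 1, 28, 28)
test_data = torch.randn(input_shape)

# Convert the quantized model into a TensorRT engine and run inference with it.
engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=32)
engine.compress()
output, latency = engine.inference(test_data)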
@@ -12,7 +12,8 @@ from . import calibrator as calibrator
 from . import trt_pycuda as common
 from .backend import BaseModelSpeedup
-# TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+TRT8 = 8
+TRT7 = 7
 TRT_LOGGER = trt.Logger()
 logger = logging.getLogger(__name__)
@@ -120,22 +121,43 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
         An ICudaEngine for executing inference on a built network
     """
     with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, \
-        trt.OnnxParser(network, TRT_LOGGER) as parser:
+        trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as trt_config:
         # Attention that, builder should be set to 1 because of the implementation of allocate_buffer
+        trt_version = int(trt.__version__[0])
+        assert trt_version == TRT8 or trt_version == TRT7, "Version of TensorRT is too old, please \
+            update TensorRT to version >= 7.0"
+        if trt_version == TRT7:
+            logger.warning("TensorRT7 is deprecated and may be removed in the following release.")
         builder.max_batch_size = 1
-        builder.max_workspace_size = common.GiB(4)
+        if trt_version == TRT8:
+            trt_config.max_workspace_size = common.GiB(4)
+        else:
+            builder.max_workspace_size = common.GiB(4)
         if extra_layer_bit == 32 and config is None:
             pass
         elif extra_layer_bit == 16 and config is None:
-            builder.fp16_mode = True
+            if trt_version == TRT8:
+                trt_config.set_flag(trt.BuilderFlag.FP16)
+            else:
+                builder.fp16_mode = True
         elif extra_layer_bit == 8 and config is None:
             # entire model in 8bit mode
-            builder.int8_mode = True
+            if trt_version == TRT8:
+                trt_config.set_flag(trt.BuilderFlag.INT8)
+            else:
+                builder.int8_mode = True
         else:
-            builder.int8_mode = True
-            builder.fp16_mode = True
-            builder.strict_type_constraints = strict_datatype
+            if trt_version == TRT8:
+                trt_config.set_flag(trt.BuilderFlag.INT8)
+                trt_config.set_flag(trt.BuilderFlag.FP16)
+                if strict_datatype:
+                    trt_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
+            else:
+                builder.int8_mode = True
+                builder.fp16_mode = True
+                builder.strict_type_constraints = strict_datatype
         valid_config(config)
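Taken out of the diff context, the version split introduced in this hunk boils down to the following standalone sketch (a minimal illustration, assuming TensorRT 7 or 8 is installed; ONNX parsing and calibration are omitted, and the helper name `configure_precision` is made up for illustration):

import tensorrt as trt

TRT_LOGGER = trt.Logger()
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def configure_precision(fp16=False, int8=False, strict=False, workspace_bytes=1 << 32):
    # Build an empty network and set precision/workspace options in a
    # version-dependent way, mirroring the pattern in the hunk above.
    trt_version = int(trt.__version__[0])
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)
    trt_config = builder.create_builder_config()

    if trt_version >= 8:
        # TensorRT 8: precision and workspace live on IBuilderConfig.
        trt_config.max_workspace_size = workspace_bytes
        if fp16:
            trt_config.set_flag(trt.BuilderFlag.FP16)
        if int8:
            trt_config.set_flag(trt.BuilderFlag.INT8)
        if strict:
            trt_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
    else:
        # TensorRT 7: the same settings are (since-deprecated) builder attributes.
        builder.max_workspace_size = workspace_bytes
        builder.fp16_mode = fp16
        builder.int8_mode = int8
        builder.strict_type_constraints = strict
    return builder, network, trt_config

Dispatching on int(trt.__version__[0]) keeps a single code path working against both major versions, at the cost of carrying the deprecated TensorRT 7 attributes until that fallback is dropped.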
@@ -148,7 +170,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
                 return None
         if calib is not None:
-            builder.int8_calibrator = calib
+            if trt_version == TRT8:
+                trt_config.int8_calibrator = calib
+            else:
+                builder.int8_calibrator = calib
         # This design may not be correct if output more than one
         for i in range(network.num_layers):
             if config is None:
@@ -196,7 +221,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
                 out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)
         # Build engine and do int8 calibration.
-        engine = builder.build_cuda_engine(network)
+        if trt_version == TRT8:
+            engine = builder.build_engine(network, trt_config)
+        else:
+            engine = builder.build_cuda_engine(network)
         return engine

 class ModelSpeedupTensorRT(BaseModelSpeedup):
...
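The last two hunks apply the same version dispatch to calibration and to the final engine build; condensed into one hedged sketch, assuming `builder`, `network`, and `trt_config` come from a helper like the one above and `calib` is an optional INT8 calibrator (the helper name `build_trt_engine` is made up for illustration):

import tensorrt as trt

def build_trt_engine(builder, network, trt_config, calib=None):
    # Attach an optional INT8 calibrator and build the engine, dispatching
    # on the TensorRT major version as the diff above does.
    trt_version = int(trt.__version__[0])
    if calib is not None:
        if trt_version >= 8:
            trt_config.int8_calibrator = calib
        else:
            builder.int8_calibrator = calib
    if trt_version >= 8:
        # build_cuda_engine was removed in TensorRT 8; pass the builder config explicitly.
        return builder.build_engine(network, trt_config)
    return builder.build_cuda_engine(network)

In TensorRT 8 the old builder.* precision attributes and build_cuda_engine are gone, which is why the diff routes everything through trt_config whenever the major version is 8.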