Unverified Commit 13888df2 authored by lizz's avatar lizz Committed by GitHub
Browse files

Fix typos (#1041)



* Fix typos
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Add deprecation warning
Signed-off-by: default avatarlizz <lizz@sensetime.com>
parent 9d1436fb
# flake8: noqa # flake8: noqa
from .init_plugins import is_tensorrt_plugin_loaded, load_tensorrt_plugin from .init_plugins import is_tensorrt_plugin_loaded, load_tensorrt_plugin
from .tensorrt_utils import (TRTWraper, load_trt_engine, onnx2trt, from .tensorrt_utils import (TRTWraper, TRTWrapper, load_trt_engine, onnx2trt,
save_trt_engine) save_trt_engine)
# load tensorrt plugin lib # load tensorrt plugin lib
load_tensorrt_plugin() load_tensorrt_plugin()
__all__ = [ __all__ = [
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper', 'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper',
'is_tensorrt_plugin_loaded' 'TRTWraper', 'is_tensorrt_plugin_loaded'
] ]
import warnings
import numpy as np import numpy as np
import onnx import onnx
import tensorrt as trt import tensorrt as trt
...@@ -40,7 +42,7 @@ def preprocess_onnx(onnx_model): ...@@ -40,7 +42,7 @@ def preprocess_onnx(onnx_model):
elif name in init_dict: elif name in init_dict:
raw_data = init_dict[name].raw_data raw_data = init_dict[name].raw_data
else: else:
raise ValueError(f'{name} not found in node or initilizer.') raise ValueError(f'{name} not found in node or initializer.')
return np.frombuffer(raw_data, typ).item() return np.frombuffer(raw_data, typ).item()
nrof_node = len(nodes) nrof_node = len(nodes)
...@@ -225,8 +227,8 @@ def torch_device_from_trt(device): ...@@ -225,8 +227,8 @@ def torch_device_from_trt(device):
return TypeError('%s is not supported by torch' % device) return TypeError('%s is not supported by torch' % device)
class TRTWraper(torch.nn.Module): class TRTWrapper(torch.nn.Module):
"""TensorRT engine Wraper. """TensorRT engine Wrapper.
Arguments: Arguments:
engine (tensorrt.ICudaEngine): TensorRT engine to wrap engine (tensorrt.ICudaEngine): TensorRT engine to wrap
...@@ -239,7 +241,7 @@ class TRTWraper(torch.nn.Module): ...@@ -239,7 +241,7 @@ class TRTWraper(torch.nn.Module):
""" """
def __init__(self, engine, input_names=None, output_names=None): def __init__(self, engine, input_names=None, output_names=None):
super(TRTWraper, self).__init__() super(TRTWrapper, self).__init__()
self.engine = engine self.engine = engine
if isinstance(self.engine, str): if isinstance(self.engine, str):
self.engine = load_trt_engine(engine) self.engine = load_trt_engine(engine)
...@@ -247,7 +249,7 @@ class TRTWraper(torch.nn.Module): ...@@ -247,7 +249,7 @@ class TRTWraper(torch.nn.Module):
if not isinstance(self.engine, trt.ICudaEngine): if not isinstance(self.engine, trt.ICudaEngine):
raise TypeError('engine should be str or trt.ICudaEngine') raise TypeError('engine should be str or trt.ICudaEngine')
self._register_state_dict_hook(TRTWraper._on_state_dict) self._register_state_dict_hook(TRTWrapper._on_state_dict)
self.context = self.engine.create_execution_context() self.context = self.engine.create_execution_context()
# get input and output names from engine # get input and output names from engine
...@@ -310,3 +312,11 @@ class TRTWraper(torch.nn.Module): ...@@ -310,3 +312,11 @@ class TRTWraper(torch.nn.Module):
torch.cuda.current_stream().cuda_stream) torch.cuda.current_stream().cuda_stream)
return outputs return outputs
class TRTWraper(TRTWrapper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn('TRTWraper will be deprecated in'
' future. Please use TRTWrapper instead')
...@@ -266,7 +266,7 @@ def requires_executable(prerequisites): ...@@ -266,7 +266,7 @@ def requires_executable(prerequisites):
def deprecated_api_warning(name_dict, cls_name=None): def deprecated_api_warning(name_dict, cls_name=None):
"""A decorator to check if some argments are deprecate and try to replace """A decorator to check if some arguments are deprecate and try to replace
deprecate src_arg_name to dst_arg_name. deprecate src_arg_name to dst_arg_name.
Args: Args:
......
...@@ -15,7 +15,7 @@ def imshow(img, win_name='', wait_time=0): ...@@ -15,7 +15,7 @@ def imshow(img, win_name='', wait_time=0):
wait_time (int): Value of waitKey param. wait_time (int): Value of waitKey param.
""" """
cv2.imshow(win_name, imread(img)) cv2.imshow(win_name, imread(img))
if wait_time == 0: # prevent from hangning if windows was closed if wait_time == 0: # prevent from hanging if windows was closed
while True: while True:
ret = cv2.waitKey(1) ret = cv2.waitKey(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment