Unverified Commit 13888df2 authored by lizz's avatar lizz Committed by GitHub
Browse files

Fix typos (#1041)



* Fix typos
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Add deprecation warning
Signed-off-by: default avatarlizz <lizz@sensetime.com>
parent 9d1436fb
# flake8: noqa
from .init_plugins import is_tensorrt_plugin_loaded, load_tensorrt_plugin
from .tensorrt_utils import (TRTWraper, load_trt_engine, onnx2trt,
from .tensorrt_utils import (TRTWraper, TRTWrapper, load_trt_engine, onnx2trt,
save_trt_engine)
# load tensorrt plugin lib
load_tensorrt_plugin()
__all__ = [
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
'is_tensorrt_plugin_loaded'
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper',
'TRTWraper', 'is_tensorrt_plugin_loaded'
]
import warnings
import numpy as np
import onnx
import tensorrt as trt
......@@ -40,7 +42,7 @@ def preprocess_onnx(onnx_model):
elif name in init_dict:
raw_data = init_dict[name].raw_data
else:
raise ValueError(f'{name} not found in node or initilizer.')
raise ValueError(f'{name} not found in node or initializer.')
return np.frombuffer(raw_data, typ).item()
nrof_node = len(nodes)
......@@ -225,8 +227,8 @@ def torch_device_from_trt(device):
return TypeError('%s is not supported by torch' % device)
class TRTWraper(torch.nn.Module):
"""TensorRT engine Wraper.
class TRTWrapper(torch.nn.Module):
"""TensorRT engine Wrapper.
Arguments:
engine (tensorrt.ICudaEngine): TensorRT engine to wrap
......@@ -239,7 +241,7 @@ class TRTWraper(torch.nn.Module):
"""
def __init__(self, engine, input_names=None, output_names=None):
super(TRTWraper, self).__init__()
super(TRTWrapper, self).__init__()
self.engine = engine
if isinstance(self.engine, str):
self.engine = load_trt_engine(engine)
......@@ -247,7 +249,7 @@ class TRTWraper(torch.nn.Module):
if not isinstance(self.engine, trt.ICudaEngine):
raise TypeError('engine should be str or trt.ICudaEngine')
self._register_state_dict_hook(TRTWraper._on_state_dict)
self._register_state_dict_hook(TRTWrapper._on_state_dict)
self.context = self.engine.create_execution_context()
# get input and output names from engine
......@@ -310,3 +312,11 @@ class TRTWraper(torch.nn.Module):
torch.cuda.current_stream().cuda_stream)
return outputs
class TRTWraper(TRTWrapper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn('TRTWraper will be deprecated in'
' future. Please use TRTWrapper instead')
......@@ -266,7 +266,7 @@ def requires_executable(prerequisites):
def deprecated_api_warning(name_dict, cls_name=None):
"""A decorator to check if some argments are deprecate and try to replace
"""A decorator to check if some arguments are deprecate and try to replace
deprecate src_arg_name to dst_arg_name.
Args:
......
......@@ -15,7 +15,7 @@ def imshow(img, win_name='', wait_time=0):
wait_time (int): Value of waitKey param.
"""
cv2.imshow(win_name, imread(img))
if wait_time == 0: # prevent from hangning if windows was closed
if wait_time == 0: # prevent from hanging if windows was closed
while True:
ret = cv2.waitKey(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment