Unverified Commit 596386af authored by SWHL's avatar SWHL Committed by GitHub
Browse files

Update utils.py

parent c3c9b6b0
...@@ -300,12 +300,12 @@ class OrtInferSession(): ...@@ -300,12 +300,12 @@ class OrtInferSession():
RuntimeWarning) RuntimeWarning)
def __call__(self, def __call__(self,
input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray: input_content: List[np.ndarray]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content)) input_dict = dict(zip(self.get_input_names(), input_content))
try: try:
return self.session.run(None, input_dict) return self.session.run(None, input_dict)
except Exception as e: except Exception as e:
raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e raise ONNXRuntimeError('ONNXRuntime inference failed.') from e
def get_input_names(self, ): def get_input_names(self, ):
return [v.name for v in self.session.get_inputs()] return [v.name for v in self.session.get_inputs()]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment