Unverified Commit 158c5c4d authored by hlky's avatar hlky Committed by GitHub
Browse files

Add provider_options to OnnxRuntimeModel (#10661)

parent 41571773
......@@ -61,7 +61,7 @@ class OnnxRuntimeModel:
return self.model.run(None, inputs)
@staticmethod
def load_model(path: Union[str, Path], provider=None, sess_options=None):
def load_model(path: Union[str, Path], provider=None, sess_options=None, provider_options=None):
"""
Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
......@@ -75,7 +75,9 @@ class OnnxRuntimeModel:
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"
return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
return ort.InferenceSession(
path, providers=[provider], sess_options=sess_options, provider_options=provider_options
)
def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment