"csrc/ktransformers_ext/cpu_backend/backend.cpp" did not exist on "18c42e67df28bb6c7f5dc847595637327919e5ea"
Unverified Commit 158c5c4d authored by hlky's avatar hlky Committed by GitHub
Browse files

Add provider_options to OnnxRuntimeModel (#10661)

parent 41571773
...@@ -61,7 +61,7 @@ class OnnxRuntimeModel: ...@@ -61,7 +61,7 @@ class OnnxRuntimeModel:
return self.model.run(None, inputs) return self.model.run(None, inputs)
@staticmethod @staticmethod
def load_model(path: Union[str, Path], provider=None, sess_options=None): def load_model(path: Union[str, Path], provider=None, sess_options=None, provider_options=None):
""" """
Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
...@@ -75,7 +75,9 @@ class OnnxRuntimeModel: ...@@ -75,7 +75,9 @@ class OnnxRuntimeModel:
logger.info("No onnxruntime provider specified, using CPUExecutionProvider") logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider" provider = "CPUExecutionProvider"
return ort.InferenceSession(path, providers=[provider], sess_options=sess_options) return ort.InferenceSession(
path, providers=[provider], sess_options=sess_options, provider_options=provider_options
)
def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment