Unverified Commit 8211b622 authored by cloudhan's avatar cloudhan Committed by GitHub
Browse files

Allow passing session_options for ORT backend (#620)

parent ce31f83d
...@@ -46,7 +46,7 @@ class OnnxRuntimeModel: ...@@ -46,7 +46,7 @@ class OnnxRuntimeModel:
return self.model.run(None, inputs) return self.model.run(None, inputs)
@staticmethod @staticmethod
def load_model(path: Union[str, Path], provider=None): def load_model(path: Union[str, Path], provider=None, sess_options=None):
""" """
Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
...@@ -60,7 +60,7 @@ class OnnxRuntimeModel: ...@@ -60,7 +60,7 @@ class OnnxRuntimeModel:
logger.info("No onnxruntime provider specified, using CPUExecutionProvider") logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider" provider = "CPUExecutionProvider"
return ort.InferenceSession(path, providers=[provider]) return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
""" """
...@@ -114,6 +114,7 @@ class OnnxRuntimeModel: ...@@ -114,6 +114,7 @@ class OnnxRuntimeModel:
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
file_name: Optional[str] = None, file_name: Optional[str] = None,
provider: Optional[str] = None, provider: Optional[str] = None,
sess_options: Optional[ort.SessionOptions] = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -143,7 +144,9 @@ class OnnxRuntimeModel: ...@@ -143,7 +144,9 @@ class OnnxRuntimeModel:
model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
# load model from local directory # load model from local directory
if os.path.isdir(model_id): if os.path.isdir(model_id):
model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider) model = OnnxRuntimeModel.load_model(
os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
)
kwargs["model_save_dir"] = Path(model_id) kwargs["model_save_dir"] = Path(model_id)
# load model from hub # load model from hub
else: else:
...@@ -158,7 +161,7 @@ class OnnxRuntimeModel: ...@@ -158,7 +161,7 @@ class OnnxRuntimeModel:
) )
kwargs["model_save_dir"] = Path(model_cache_path).parent kwargs["model_save_dir"] = Path(model_cache_path).parent
kwargs["latest_model_name"] = Path(model_cache_path).name kwargs["latest_model_name"] = Path(model_cache_path).name
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider) model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
return cls(model=model, **kwargs) return cls(model=model, **kwargs)
@classmethod @classmethod
......
...@@ -282,6 +282,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -282,6 +282,7 @@ class DiffusionPipeline(ConfigMixin):
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
provider = kwargs.pop("provider", None) provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
...@@ -398,6 +399,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -398,6 +399,7 @@ class DiffusionPipeline(ConfigMixin):
loading_kwargs["torch_dtype"] = torch_dtype loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel): if issubclass(class_obj, diffusers.OnnxRuntimeModel):
loading_kwargs["provider"] = provider loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
# check if the module is in a subdirectory # check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)): if os.path.isdir(os.path.join(cached_folder, name)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment