Unverified Commit 4a9ab650 authored by urpetkov-amd's avatar urpetkov-amd Committed by GitHub
Browse files

Fixing missing provider options argument (#11397)



* Fixing missing provider options argument

* Adding if else for provider options

* Apply suggestions from code review
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Apply style fixes

* Update src/diffusers/pipelines/onnx_utils.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/onnx_utils.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

---------
Co-authored-by: default avatarUros Petkovic <urpektov@amd.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 0ac1d5b4
......@@ -75,6 +75,11 @@ class OnnxRuntimeModel:
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"
if provider_options is None:
provider_options = []
elif not isinstance(provider_options, list):
provider_options = [provider_options]
return ort.InferenceSession(
path, providers=[provider], sess_options=sess_options, provider_options=provider_options
)
......@@ -174,7 +179,10 @@ class OnnxRuntimeModel:
# load model from local directory
if os.path.isdir(model_id):
model = OnnxRuntimeModel.load_model(
Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
Path(model_id, model_file_name).as_posix(),
provider=provider,
sess_options=sess_options,
provider_options=kwargs.pop("provider_options"),
)
kwargs["model_save_dir"] = Path(model_id)
# load model from hub
......@@ -190,7 +198,12 @@ class OnnxRuntimeModel:
)
kwargs["model_save_dir"] = Path(model_cache_path).parent
kwargs["latest_model_name"] = Path(model_cache_path).name
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
model = OnnxRuntimeModel.load_model(
model_cache_path,
provider=provider,
sess_options=sess_options,
provider_options=kwargs.pop("provider_options"),
)
return cls(model=model, **kwargs)
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment