Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
8211b622
Unverified
Commit
8211b622
authored
Sep 23, 2022
by
cloudhan
Committed by
GitHub
Sep 23, 2022
Browse files
Allow passing session_options for ORT backend (#620)
parent
ce31f83d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
4 deletions
+9
-4
src/diffusers/onnx_utils.py
src/diffusers/onnx_utils.py
+7
-4
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+2
-0
No files found.
src/diffusers/onnx_utils.py
View file @
8211b622
...
...
@@ -46,7 +46,7 @@ class OnnxRuntimeModel:
return
self
.
model
.
run
(
None
,
inputs
)
@
staticmethod
def
load_model
(
path
:
Union
[
str
,
Path
],
provider
=
None
):
def
load_model
(
path
:
Union
[
str
,
Path
],
provider
=
None
,
sess_options
=
None
):
"""
Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
...
...
@@ -60,7 +60,7 @@ class OnnxRuntimeModel:
logger
.
info
(
"No onnxruntime provider specified, using CPUExecutionProvider"
)
provider
=
"CPUExecutionProvider"
return
ort
.
InferenceSession
(
path
,
providers
=
[
provider
])
return
ort
.
InferenceSession
(
path
,
providers
=
[
provider
]
,
sess_options
=
sess_options
)
def
_save_pretrained
(
self
,
save_directory
:
Union
[
str
,
Path
],
file_name
:
Optional
[
str
]
=
None
,
**
kwargs
):
"""
...
...
@@ -114,6 +114,7 @@ class OnnxRuntimeModel:
cache_dir
:
Optional
[
str
]
=
None
,
file_name
:
Optional
[
str
]
=
None
,
provider
:
Optional
[
str
]
=
None
,
sess_options
:
Optional
[
ort
.
SessionOptions
]
=
None
,
**
kwargs
,
):
"""
...
...
@@ -143,7 +144,9 @@ class OnnxRuntimeModel:
model_file_name
=
file_name
if
file_name
is
not
None
else
ONNX_WEIGHTS_NAME
# load model from local directory
if
os
.
path
.
isdir
(
model_id
):
model
=
OnnxRuntimeModel
.
load_model
(
os
.
path
.
join
(
model_id
,
model_file_name
),
provider
=
provider
)
model
=
OnnxRuntimeModel
.
load_model
(
os
.
path
.
join
(
model_id
,
model_file_name
),
provider
=
provider
,
sess_options
=
sess_options
)
kwargs
[
"model_save_dir"
]
=
Path
(
model_id
)
# load model from hub
else
:
...
...
@@ -158,7 +161,7 @@ class OnnxRuntimeModel:
)
kwargs
[
"model_save_dir"
]
=
Path
(
model_cache_path
).
parent
kwargs
[
"latest_model_name"
]
=
Path
(
model_cache_path
).
name
model
=
OnnxRuntimeModel
.
load_model
(
model_cache_path
,
provider
=
provider
)
model
=
OnnxRuntimeModel
.
load_model
(
model_cache_path
,
provider
=
provider
,
sess_options
=
sess_options
)
return
cls
(
model
=
model
,
**
kwargs
)
@
classmethod
...
...
src/diffusers/pipeline_utils.py
View file @
8211b622
...
...
@@ -282,6 +282,7 @@ class DiffusionPipeline(ConfigMixin):
revision
=
kwargs
.
pop
(
"revision"
,
None
)
torch_dtype
=
kwargs
.
pop
(
"torch_dtype"
,
None
)
provider
=
kwargs
.
pop
(
"provider"
,
None
)
sess_options
=
kwargs
.
pop
(
"sess_options"
,
None
)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
...
...
@@ -398,6 +399,7 @@ class DiffusionPipeline(ConfigMixin):
loading_kwargs
[
"torch_dtype"
]
=
torch_dtype
if
issubclass
(
class_obj
,
diffusers
.
OnnxRuntimeModel
):
loading_kwargs
[
"provider"
]
=
provider
loading_kwargs
[
"sess_options"
]
=
sess_options
# check if the module is in a subdirectory
if
os
.
path
.
isdir
(
os
.
path
.
join
(
cached_folder
,
name
)):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment