Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
5f826a35
Unverified
Commit
5f826a35
authored
Mar 06, 2023
by
ForserX
Committed by
GitHub
Mar 06, 2023
Browse files
Add custom vae (diffusers type) to onnx converter (#2325)
parent
f7278638
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
126 additions
and
0 deletions
+126
-0
scripts/convert_vae_diff_to_onnx.py
scripts/convert_vae_diff_to_onnx.py
+126
-0
No files found.
scripts/convert_vae_diff_to_onnx.py
0 → 100644
View file @
5f826a35
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
shutil
from
pathlib
import
Path
import
torch
from
torch.onnx
import
export
import
onnx
from
diffusers
import
OnnxRuntimeModel
,
OnnxStableDiffusionPipeline
,
StableDiffusionPipeline
,
AutoencoderKL
from
packaging
import
version
is_torch_less_than_1_11
=
version
.
parse
(
version
.
parse
(
torch
.
__version__
).
base_version
)
<
version
.
parse
(
"1.11"
)
def
onnx_export
(
model
,
model_args
:
tuple
,
output_path
:
Path
,
ordered_input_names
,
output_names
,
dynamic_axes
,
opset
,
use_external_data_format
=
False
,
):
output_path
.
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if
is_torch_less_than_1_11
:
export
(
model
,
model_args
,
f
=
output_path
.
as_posix
(),
input_names
=
ordered_input_names
,
output_names
=
output_names
,
dynamic_axes
=
dynamic_axes
,
do_constant_folding
=
True
,
use_external_data_format
=
use_external_data_format
,
enable_onnx_checker
=
True
,
opset_version
=
opset
,
)
else
:
export
(
model
,
model_args
,
f
=
output_path
.
as_posix
(),
input_names
=
ordered_input_names
,
output_names
=
output_names
,
dynamic_axes
=
dynamic_axes
,
do_constant_folding
=
True
,
opset_version
=
opset
,
)
@
torch
.
no_grad
()
def
convert_models
(
model_path
:
str
,
output_path
:
str
,
opset
:
int
,
fp16
:
bool
=
False
):
dtype
=
torch
.
float16
if
fp16
else
torch
.
float32
if
fp16
and
torch
.
cuda
.
is_available
():
device
=
"cuda"
elif
fp16
and
not
torch
.
cuda
.
is_available
():
raise
ValueError
(
"`float16` model export is only supported on GPUs with CUDA"
)
else
:
device
=
"cpu"
output_path
=
Path
(
output_path
)
# VAE DECODER
vae_decoder
=
AutoencoderKL
.
from_pretrained
(
model_path
+
"/vae"
)
vae_latent_channels
=
vae_decoder
.
config
.
latent_channels
vae_out_channels
=
vae_decoder
.
config
.
out_channels
# forward only through the decoder part
vae_decoder
.
forward
=
vae_decoder
.
decode
onnx_export
(
vae_decoder
,
model_args
=
(
torch
.
randn
(
1
,
vae_latent_channels
,
25
,
25
).
to
(
device
=
device
,
dtype
=
dtype
),
False
,
),
output_path
=
output_path
/
"vae_decoder"
/
"model.onnx"
,
ordered_input_names
=
[
"latent_sample"
,
"return_dict"
],
output_names
=
[
"sample"
],
dynamic_axes
=
{
"latent_sample"
:
{
0
:
"batch"
,
1
:
"channels"
,
2
:
"height"
,
3
:
"width"
},
},
opset
=
opset
,
)
del
vae_decoder
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model_path"
,
type
=
str
,
required
=
True
,
help
=
"Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub)."
,
)
parser
.
add_argument
(
"--output_path"
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
parser
.
add_argument
(
"--opset"
,
default
=
14
,
type
=
int
,
help
=
"The version of the ONNX operator set to use."
,
)
parser
.
add_argument
(
"--fp16"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Export the models in `float16` mode"
)
args
=
parser
.
parse_args
()
print
(
args
.
output_path
)
convert_models
(
args
.
model_path
,
args
.
output_path
,
args
.
opset
,
args
.
fp16
)
print
(
"SD: Done: ONNX"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment