Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
5f826a35
Unverified
Commit
5f826a35
authored
Mar 06, 2023
by
ForserX
Committed by
GitHub
Mar 06, 2023
Browse files
Add custom vae (diffusers type) to onnx converter (#2325)
parent
f7278638
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
126 additions
and
0 deletions
+126
-0
scripts/convert_vae_diff_to_onnx.py
scripts/convert_vae_diff_to_onnx.py
+126
-0
No files found.
scripts/convert_vae_diff_to_onnx.py
0 → 100644
View file @
5f826a35
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
shutil
from
pathlib
import
Path
import
torch
from
torch.onnx
import
export
import
onnx
from
diffusers
import
OnnxRuntimeModel
,
OnnxStableDiffusionPipeline
,
StableDiffusionPipeline
,
AutoencoderKL
from
packaging
import
version
# torch.onnx.export dropped some keyword arguments in v1.11; record once at
# import time which API flavor this torch installation expects.
# `base_version` strips local suffixes such as "+cu117" before comparing.
_torch_base_version = version.parse(version.parse(torch.__version__).base_version)
is_torch_less_than_1_11 = _torch_base_version < version.parse("1.11")
def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    """Trace `model(*model_args)` and write the resulting ONNX graph to `output_path`.

    Args:
        model: The torch module to export.
        model_args: Positional example inputs used for tracing.
        output_path: Destination file; parent directories are created as needed.
        ordered_input_names: Names assigned to the graph inputs, in call order.
        output_names: Names assigned to the graph outputs.
        dynamic_axes: Mapping of input name -> {axis index: axis name} for
            axes that may vary at inference time.
        opset: ONNX operator-set version to target.
        use_external_data_format: Store weights outside the .onnx file
            (only honored on torch < 1.11, where the argument still exists).
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Arguments accepted by every supported torch version.
    export_kwargs = {
        "input_names": ordered_input_names,
        "output_names": output_names,
        "dynamic_axes": dynamic_axes,
        "do_constant_folding": True,
        "opset_version": opset,
    }
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format`
    # arguments in v1.11, so we check the torch version for backwards compatibility.
    if is_torch_less_than_1_11:
        export_kwargs["use_external_data_format"] = use_external_data_format
        export_kwargs["enable_onnx_checker"] = True

    export(model, model_args, f=output_path.as_posix(), **export_kwargs)
@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
    """Convert the VAE of a `diffusers` checkpoint to ONNX.

    Only the decoder half is exported: `forward` is rebound to `decode`, so the
    traced graph maps latents -> images.

    Args:
        model_path: Local directory or Hub id of the diffusers checkpoint; the
            VAE is loaded from its `vae` subfolder.
        output_path: Directory that will receive `vae_decoder/model.onnx`.
        opset: ONNX operator-set version to target.
        fp16: Export the model in float16 (requires a CUDA device).

    Raises:
        ValueError: If `fp16` is requested but CUDA is not available.
    """
    dtype = torch.float16 if fp16 else torch.float32
    if fp16 and torch.cuda.is_available():
        device = "cuda"
    elif fp16 and not torch.cuda.is_available():
        raise ValueError("`float16` model export is only supported on GPUs with CUDA")
    else:
        device = "cpu"
    output_path = Path(output_path)

    # VAE DECODER
    vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae")
    vae_latent_channels = vae_decoder.config.latent_channels
    # NOTE: the original also read `vae_decoder.config.out_channels` into an
    # unused local (`vae_out_channels`); that dead assignment has been removed.
    # forward only through the decoder part
    vae_decoder.forward = vae_decoder.decode
    onnx_export(
        vae_decoder,
        model_args=(
            # 25x25 is an arbitrary tracing size; the spatial dims are declared
            # dynamic below, so real inputs may use any height/width.
            torch.randn(1, vae_latent_channels, 25, 25).to(device=device, dtype=dtype),
            False,  # return_dict
        ),
        output_path=output_path / "vae_decoder" / "model.onnx",
        ordered_input_names=["latent_sample", "return_dict"],
        output_names=["sample"],
        dynamic_axes={
            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )
    # Release the model promptly; exports can be memory-heavy.
    del vae_decoder
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the VAE -> ONNX conversion.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
    )
    arg_parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")
    arg_parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    arg_parser.add_argument(
        "--fp16", action="store_true", default=False, help="Export the models in `float16` mode"
    )
    cli_args = arg_parser.parse_args()

    print(cli_args.output_path)
    convert_models(cli_args.model_path, cli_args.output_path, cli_args.opset, cli_args.fp16)
    print("SD: Done: ONNX")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment