Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
86aa747d
Unverified
Commit
86aa747d
authored
Nov 25, 2022
by
Anton Lozhkov
Committed by
GitHub
Nov 25, 2022
Browse files
Fix ONNX conversion and inference (#1416)
parent
d52388f4
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
94 deletions
+18
-94
scripts/convert_stable_diffusion_checkpoint_to_onnx.py
scripts/convert_stable_diffusion_checkpoint_to_onnx.py
+4
-1
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
...elines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+5
-35
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
...table_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+2
-22
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
...table_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+7
-36
No files found.
scripts/convert_stable_diffusion_checkpoint_to_onnx.py
View file @
86aa747d
...
@@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
...
@@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
)
)
del
pipeline
.
safety_checker
del
pipeline
.
safety_checker
safety_checker
=
OnnxRuntimeModel
.
from_pretrained
(
output_path
/
"safety_checker"
)
safety_checker
=
OnnxRuntimeModel
.
from_pretrained
(
output_path
/
"safety_checker"
)
feature_extractor
=
pipeline
.
feature_extractor
else
:
else
:
safety_checker
=
None
safety_checker
=
None
feature_extractor
=
None
onnx_pipeline
=
OnnxStableDiffusionPipeline
(
onnx_pipeline
=
OnnxStableDiffusionPipeline
(
vae_encoder
=
OnnxRuntimeModel
.
from_pretrained
(
output_path
/
"vae_encoder"
),
vae_encoder
=
OnnxRuntimeModel
.
from_pretrained
(
output_path
/
"vae_encoder"
),
...
@@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
...
@@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
unet
=
OnnxRuntimeModel
.
from_pretrained
(
output_path
/
"unet"
),
unet
=
OnnxRuntimeModel
.
from_pretrained
(
output_path
/
"unet"
),
scheduler
=
pipeline
.
scheduler
,
scheduler
=
pipeline
.
scheduler
,
safety_checker
=
safety_checker
,
safety_checker
=
safety_checker
,
feature_extractor
=
pipeline
.
feature_extractor
,
feature_extractor
=
feature_extractor
,
requires_safety_checker
=
safety_checker
is
not
None
,
)
)
onnx_pipeline
.
save_pretrained
(
output_path
)
onnx_pipeline
.
save_pretrained
(
output_path
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
View file @
86aa747d
...
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
...
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
packaging
import
version
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
from
...configuration_utils
import
FrozenDict
from
...configuration_utils
import
FrozenDict
...
@@ -42,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
...
@@ -42,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker
:
OnnxRuntimeModel
safety_checker
:
OnnxRuntimeModel
feature_extractor
:
CLIPFeatureExtractor
feature_extractor
:
CLIPFeatureExtractor
_optional_components
=
[
"safety_checker"
,
"feature_extractor"
]
def
__init__
(
def
__init__
(
self
,
self
,
vae_encoder
:
OnnxRuntimeModel
,
vae_encoder
:
OnnxRuntimeModel
,
...
@@ -99,27 +100,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
...
@@ -99,27 +100,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
)
is_unet_version_less_0_9_0
=
hasattr
(
unet
.
config
,
"_diffusers_version"
)
and
version
.
parse
(
version
.
parse
(
unet
.
config
.
_diffusers_version
).
base_version
)
<
version
.
parse
(
"0.9.0.dev0"
)
is_unet_sample_size_less_64
=
hasattr
(
unet
.
config
,
"sample_size"
)
and
unet
.
config
.
sample_size
<
64
if
is_unet_version_less_0_9_0
and
is_unet_sample_size_less_64
:
deprecation_message
=
(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following:
\n
- CompVis/stable-diffusion-v1-4
\n
- CompVis/stable-diffusion-v1-3
\n
-"
" CompVis/stable-diffusion-v1-2
\n
- CompVis/stable-diffusion-v1-1
\n
- runwayml/stable-diffusion-v1-5"
"
\n
- runwayml/stable-diffusion-inpainting
\n
you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate
(
"sample_size<64"
,
"1.0.0"
,
deprecation_message
,
standard_warn
=
False
)
new_config
=
dict
(
unet
.
config
)
new_config
[
"sample_size"
]
=
64
unet
.
_internal_dict
=
FrozenDict
(
new_config
)
self
.
register_modules
(
self
.
register_modules
(
vae_encoder
=
vae_encoder
,
vae_encoder
=
vae_encoder
,
vae_decoder
=
vae_decoder
,
vae_decoder
=
vae_decoder
,
...
@@ -130,7 +110,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
...
@@ -130,7 +110,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker
=
safety_checker
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
feature_extractor
=
feature_extractor
,
)
)
self
.
vae_scale_factor
=
2
**
(
len
(
self
.
vae
.
config
.
block_out_channels
)
-
1
)
self
.
register_to_config
(
requires_safety_checker
=
requires_safety_checker
)
self
.
register_to_config
(
requires_safety_checker
=
requires_safety_checker
)
def
_encode_prompt
(
self
,
prompt
,
num_images_per_prompt
,
do_classifier_free_guidance
,
negative_prompt
):
def
_encode_prompt
(
self
,
prompt
,
num_images_per_prompt
,
do_classifier_free_guidance
,
negative_prompt
):
...
@@ -213,8 +192,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
...
@@ -213,8 +192,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
def
__call__
(
def
__call__
(
self
,
self
,
prompt
:
Union
[
str
,
List
[
str
]],
prompt
:
Union
[
str
,
List
[
str
]],
height
:
Optional
[
int
]
=
None
,
height
:
Optional
[
int
]
=
512
,
width
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
512
,
num_inference_steps
:
Optional
[
int
]
=
50
,
num_inference_steps
:
Optional
[
int
]
=
50
,
guidance_scale
:
Optional
[
float
]
=
7.5
,
guidance_scale
:
Optional
[
float
]
=
7.5
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
...
@@ -228,10 +207,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
...
@@ -228,10 +207,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
callback_steps
:
Optional
[
int
]
=
1
,
callback_steps
:
Optional
[
int
]
=
1
,
**
kwargs
,
**
kwargs
,
):
):
# 0. Default height and width to unet
height
=
height
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
width
=
width
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
if
isinstance
(
prompt
,
str
):
if
isinstance
(
prompt
,
str
):
batch_size
=
1
batch_size
=
1
elif
isinstance
(
prompt
,
list
):
elif
isinstance
(
prompt
,
list
):
...
@@ -264,12 +239,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
...
@@ -264,12 +239,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# get the initial random noise unless the user supplied it
# get the initial random noise unless the user supplied it
latents_dtype
=
text_embeddings
.
dtype
latents_dtype
=
text_embeddings
.
dtype
latents_shape
=
(
latents_shape
=
(
batch_size
*
num_images_per_prompt
,
4
,
height
//
8
,
width
//
8
)
batch_size
*
num_images_per_prompt
,
4
,
height
//
self
.
vae_scale_factor
,
width
//
self
.
vae_scale_factor
,
)
if
latents
is
None
:
if
latents
is
None
:
latents
=
generator
.
randn
(
*
latents_shape
).
astype
(
latents_dtype
)
latents
=
generator
.
randn
(
*
latents_shape
).
astype
(
latents_dtype
)
elif
latents
.
shape
!=
latents_shape
:
elif
latents
.
shape
!=
latents_shape
:
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
View file @
86aa747d
...
@@ -19,7 +19,6 @@ import numpy as np
...
@@ -19,7 +19,6 @@ import numpy as np
import
torch
import
torch
import
PIL
import
PIL
from
packaging
import
version
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
from
...configuration_utils
import
FrozenDict
from
...configuration_utils
import
FrozenDict
...
@@ -78,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
...
@@ -78,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker
:
OnnxRuntimeModel
safety_checker
:
OnnxRuntimeModel
feature_extractor
:
CLIPFeatureExtractor
feature_extractor
:
CLIPFeatureExtractor
_optional_components
=
[
"safety_checker"
,
"feature_extractor"
]
def
__init__
(
def
__init__
(
self
,
self
,
vae_encoder
:
OnnxRuntimeModel
,
vae_encoder
:
OnnxRuntimeModel
,
...
@@ -135,27 +136,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
...
@@ -135,27 +136,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
)
is_unet_version_less_0_9_0
=
hasattr
(
unet
.
config
,
"_diffusers_version"
)
and
version
.
parse
(
version
.
parse
(
unet
.
config
.
_diffusers_version
).
base_version
)
<
version
.
parse
(
"0.9.0.dev0"
)
is_unet_sample_size_less_64
=
hasattr
(
unet
.
config
,
"sample_size"
)
and
unet
.
config
.
sample_size
<
64
if
is_unet_version_less_0_9_0
and
is_unet_sample_size_less_64
:
deprecation_message
=
(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following:
\n
- CompVis/stable-diffusion-v1-4
\n
- CompVis/stable-diffusion-v1-3
\n
-"
" CompVis/stable-diffusion-v1-2
\n
- CompVis/stable-diffusion-v1-1
\n
- runwayml/stable-diffusion-v1-5"
"
\n
- runwayml/stable-diffusion-inpainting
\n
you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate
(
"sample_size<64"
,
"1.0.0"
,
deprecation_message
,
standard_warn
=
False
)
new_config
=
dict
(
unet
.
config
)
new_config
[
"sample_size"
]
=
64
unet
.
_internal_dict
=
FrozenDict
(
new_config
)
self
.
register_modules
(
self
.
register_modules
(
vae_encoder
=
vae_encoder
,
vae_encoder
=
vae_encoder
,
vae_decoder
=
vae_decoder
,
vae_decoder
=
vae_decoder
,
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
View file @
86aa747d
...
@@ -19,7 +19,6 @@ import numpy as np
...
@@ -19,7 +19,6 @@ import numpy as np
import
torch
import
torch
import
PIL
import
PIL
from
packaging
import
version
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
from
...configuration_utils
import
FrozenDict
from
...configuration_utils
import
FrozenDict
...
@@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker
:
OnnxRuntimeModel
safety_checker
:
OnnxRuntimeModel
feature_extractor
:
CLIPFeatureExtractor
feature_extractor
:
CLIPFeatureExtractor
_optional_components
=
[
"safety_checker"
,
"feature_extractor"
]
def
__init__
(
def
__init__
(
self
,
self
,
vae_encoder
:
OnnxRuntimeModel
,
vae_encoder
:
OnnxRuntimeModel
,
...
@@ -149,27 +150,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -149,27 +150,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
)
is_unet_version_less_0_9_0
=
hasattr
(
unet
.
config
,
"_diffusers_version"
)
and
version
.
parse
(
version
.
parse
(
unet
.
config
.
_diffusers_version
).
base_version
)
<
version
.
parse
(
"0.9.0.dev0"
)
is_unet_sample_size_less_64
=
hasattr
(
unet
.
config
,
"sample_size"
)
and
unet
.
config
.
sample_size
<
64
if
is_unet_version_less_0_9_0
and
is_unet_sample_size_less_64
:
deprecation_message
=
(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following:
\n
- CompVis/stable-diffusion-v1-4
\n
- CompVis/stable-diffusion-v1-3
\n
-"
" CompVis/stable-diffusion-v1-2
\n
- CompVis/stable-diffusion-v1-1
\n
- runwayml/stable-diffusion-v1-5"
"
\n
- runwayml/stable-diffusion-inpainting
\n
you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate
(
"sample_size<64"
,
"1.0.0"
,
deprecation_message
,
standard_warn
=
False
)
new_config
=
dict
(
unet
.
config
)
new_config
[
"sample_size"
]
=
64
unet
.
_internal_dict
=
FrozenDict
(
new_config
)
self
.
register_modules
(
self
.
register_modules
(
vae_encoder
=
vae_encoder
,
vae_encoder
=
vae_encoder
,
vae_decoder
=
vae_decoder
,
vae_decoder
=
vae_decoder
,
...
@@ -180,7 +160,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -180,7 +160,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker
=
safety_checker
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
feature_extractor
=
feature_extractor
,
)
)
self
.
vae_scale_factor
=
2
**
(
len
(
self
.
vae
.
config
.
block_out_channels
)
-
1
)
self
.
register_to_config
(
requires_safety_checker
=
requires_safety_checker
)
self
.
register_to_config
(
requires_safety_checker
=
requires_safety_checker
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
...
@@ -267,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -267,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
prompt
:
Union
[
str
,
List
[
str
]],
prompt
:
Union
[
str
,
List
[
str
]],
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
mask_image
:
PIL
.
Image
.
Image
,
mask_image
:
PIL
.
Image
.
Image
,
height
:
Optional
[
int
]
=
None
,
height
:
Optional
[
int
]
=
512
,
width
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
512
,
num_inference_steps
:
int
=
50
,
num_inference_steps
:
int
=
50
,
guidance_scale
:
float
=
7.5
,
guidance_scale
:
float
=
7.5
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
...
@@ -296,9 +275,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -296,9 +275,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to
self.unet.config.sample_size * self.vae_scale_factor
):
height (`int`, *optional*, defaults to
512
):
The height in pixels of the generated image.
The height in pixels of the generated image.
width (`int`, *optional*, defaults to
self.unet.config.sample_size * self.vae_scale_factor
):
width (`int`, *optional*, defaults to
512
):
The width in pixels of the generated image.
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
...
@@ -343,9 +322,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -343,9 +322,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
(nsfw) content, according to the `safety_checker`.
"""
"""
# 0. Default height and width to unet
height
=
height
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
width
=
width
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
if
isinstance
(
prompt
,
str
):
if
isinstance
(
prompt
,
str
):
batch_size
=
1
batch_size
=
1
...
@@ -381,12 +357,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -381,12 +357,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
)
)
num_channels_latents
=
NUM_LATENT_CHANNELS
num_channels_latents
=
NUM_LATENT_CHANNELS
latents_shape
=
(
latents_shape
=
(
batch_size
*
num_images_per_prompt
,
num_channels_latents
,
height
//
8
,
width
//
8
)
batch_size
*
num_images_per_prompt
,
num_channels_latents
,
height
//
self
.
vae_scale_factor
,
width
//
self
.
vae_scale_factor
,
)
latents_dtype
=
text_embeddings
.
dtype
latents_dtype
=
text_embeddings
.
dtype
if
latents
is
None
:
if
latents
is
None
:
latents
=
generator
.
randn
(
*
latents_shape
).
astype
(
latents_dtype
)
latents
=
generator
.
randn
(
*
latents_shape
).
astype
(
latents_dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment