Project: renzhc / diffusers_dcu

Unverified commit 86aa747d, authored Nov 25, 2022 by Anton Lozhkov, committed by GitHub Nov 25, 2022
Parent: d52388f4

Fix ONNX conversion and inference (#1416)

Changes: 4 changed files with 18 additions and 94 deletions (+18 -94)

  scripts/convert_stable_diffusion_checkpoint_to_onnx.py                               +4  -1
  src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py           +5  -35
  src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py   +2  -22
  src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py   +7  -36
scripts/convert_stable_diffusion_checkpoint_to_onnx.py  (+4 -1)

@@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
         )
         del pipeline.safety_checker
         safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
+        feature_extractor = pipeline.feature_extractor
     else:
         safety_checker = None
+        feature_extractor = None
 
     onnx_pipeline = OnnxStableDiffusionPipeline(
         vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
         unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
         scheduler=pipeline.scheduler,
         safety_checker=safety_checker,
-        feature_extractor=pipeline.feature_extractor,
+        feature_extractor=feature_extractor,
+        requires_safety_checker=safety_checker is not None,
     )
 
     onnx_pipeline.save_pretrained(output_path)
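For context, a minimal usage sketch (not part of this commit): once the conversion script has exported a checkpoint, the resulting pipeline can be loaded and run roughly as below. The local path ./sd-onnx and the prompt are placeholders; OnnxStableDiffusionPipeline.from_pretrained and the provider argument are the existing diffusers ONNX APIs.

# Sketch only: assumes onnxruntime is installed and the conversion script has
# already written an exported pipeline to ./sd-onnx (hypothetical path).
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./sd-onnx", provider="CPUExecutionProvider"
)

# With this commit, an export without a safety checker carries
# feature_extractor=None and requires_safety_checker=False, so this call
# no longer needs a safety-checker model to be present.
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut.png")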
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py  (+5 -35)

@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
 import numpy as np
 import torch
 
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -42,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -99,27 +100,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
@@ -130,7 +110,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@@ -213,8 +192,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -228,10 +207,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor
-
         if isinstance(prompt, str):
             batch_size = 1
         elif isinstance(prompt, list):
@@ -264,12 +239,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         # get the initial random noise unless the user supplied it
         latents_dtype = text_embeddings.dtype
-        latents_shape = (
-            batch_size * num_images_per_prompt,
-            4,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
+        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
         if latents is None:
             latents = generator.randn(*latents_shape).astype(latents_dtype)
         elif latents.shape != latents_shape:
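The latents-shape change above is plain arithmetic; a small sketch (illustration only, not from the diff) of what the hard-coded factor of 8 produces for the restored 512x512 defaults:

# Illustration of the new latents shape: the ONNX pipeline now divides by 8
# directly instead of by a vae_scale_factor derived from the VAE config.
batch_size = 1
num_images_per_prompt = 1
height = width = 512  # new __call__ defaults
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
assert latents_shape == (1, 4, 64, 64)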
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py  (+2 -22)

@@ -19,7 +19,6 @@ import numpy as np
 import torch
 
 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -78,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -135,27 +136,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
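A hedged sketch of what the new _optional_components entry enables for this pipeline (./sd-onnx is a placeholder path; overriding components via keyword arguments is standard DiffusionPipeline.from_pretrained behaviour):

# Sketch: because safety_checker and feature_extractor are now optional
# components, an export that omits them (or an explicit None override)
# loads cleanly instead of erroring.
from diffusers import OnnxStableDiffusionImg2ImgPipeline

pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
    "./sd-onnx",
    provider="CPUExecutionProvider",
    safety_checker=None,
    feature_extractor=None,
)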
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py  (+7 -36)

@@ -19,7 +19,6 @@ import numpy as np
 import torch
 
 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -149,27 +150,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
@@ -180,7 +160,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@@ -267,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         prompt: Union[str, List[str]],
         image: PIL.Image.Image,
         mask_image: PIL.Image.Image,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -296,9 +275,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to 512):
                 The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -343,9 +322,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
                 (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor
-
         if isinstance(prompt, str):
             batch_size = 1
@@ -381,12 +357,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         )
 
         num_channels_latents = NUM_LATENT_CHANNELS
-        latents_shape = (
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
+        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             latents = generator.randn(*latents_shape).astype(latents_dtype)
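Finally, a minimal inpainting sketch (not from this commit) showing the restored 512 defaults in use; the checkpoint path and the init/mask image files are assumptions:

# Sketch: ./sd-inpaint-onnx is a hypothetical ONNX export of an inpainting
# checkpoint; init.png / mask.png are 512x512 images. White mask pixels are
# repainted, black pixels are preserved (see the docstring above).
import PIL.Image
from diffusers import OnnxStableDiffusionInpaintPipeline

pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
    "./sd-inpaint-onnx", provider="CPUExecutionProvider"
)
result = pipe(
    prompt="a red park bench",
    image=PIL.Image.open("init.png").convert("RGB"),
    mask_image=PIL.Image.open("mask.png").convert("L"),
    # height and width default to 512, matching the changed signature above
).images[0]
result.save("inpainted.png")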