Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
f242eba4
Unverified
Commit
f242eba4
authored
Dec 09, 2022
by
SkyTNT
Committed by
GitHub
Dec 09, 2022
Browse files
Fix lpw stable diffusion pipeline compatibility (#1622)
parent
3faf204c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
254 additions
and
119 deletions
+254
-119
examples/community/lpw_stable_diffusion.py
examples/community/lpw_stable_diffusion.py
+120
-53
examples/community/lpw_stable_diffusion_onnx.py
examples/community/lpw_stable_diffusion_onnx.py
+134
-66
No files found.
examples/community/lpw_stable_diffusion.py
View file @
f242eba4
...
...
@@ -5,14 +5,37 @@ from typing import Callable, List, Optional, Union
import
numpy
as
np
import
torch
import
diffusers
import
PIL
from
diffusers
import
SchedulerMixin
,
StableDiffusionPipeline
from
diffusers.models
import
AutoencoderKL
,
UNet2DConditionModel
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionPipelineOutput
,
StableDiffusionSafetyChecker
from
diffusers.utils
import
PIL_INTERPOLATION
,
deprecate
,
logging
from
diffusers.utils
import
deprecate
,
logging
from
packaging
import
version
from
transformers
import
CLIPFeatureExtractor
,
CLIPTextModel
,
CLIPTokenizer
try
:
from
diffusers.utils
import
PIL_INTERPOLATION
except
ImportError
:
if
version
.
parse
(
version
.
parse
(
PIL
.
__version__
).
base_version
)
>=
version
.
parse
(
"9.1.0"
):
PIL_INTERPOLATION
=
{
"linear"
:
PIL
.
Image
.
Resampling
.
BILINEAR
,
"bilinear"
:
PIL
.
Image
.
Resampling
.
BILINEAR
,
"bicubic"
:
PIL
.
Image
.
Resampling
.
BICUBIC
,
"lanczos"
:
PIL
.
Image
.
Resampling
.
LANCZOS
,
"nearest"
:
PIL
.
Image
.
Resampling
.
NEAREST
,
}
else
:
PIL_INTERPOLATION
=
{
"linear"
:
PIL
.
Image
.
LINEAR
,
"bilinear"
:
PIL
.
Image
.
BILINEAR
,
"bicubic"
:
PIL
.
Image
.
BICUBIC
,
"lanczos"
:
PIL
.
Image
.
LANCZOS
,
"nearest"
:
PIL
.
Image
.
NEAREST
,
}
# ------------------------------------------------------------------------------
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
re_attention
=
re
.
compile
(
...
...
@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def
__init__
(
self
,
vae
:
AutoencoderKL
,
text_encoder
:
CLIPTextModel
,
tokenizer
:
CLIPTokenizer
,
unet
:
UNet2DConditionModel
,
scheduler
:
SchedulerMixin
,
safety_checker
:
StableDiffusionSafetyChecker
,
feature_extractor
:
CLIPFeatureExtractor
,
requires_safety_checker
:
bool
=
True
,
):
super
().
__init__
(
vae
=
vae
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
requires_safety_checker
=
requires_safety_checker
,
)
if
version
.
parse
(
version
.
parse
(
diffusers
.
__version__
).
base_version
)
>=
version
.
parse
(
"0.9.0"
):
def
__init__
(
self
,
vae
:
AutoencoderKL
,
text_encoder
:
CLIPTextModel
,
tokenizer
:
CLIPTokenizer
,
unet
:
UNet2DConditionModel
,
scheduler
:
SchedulerMixin
,
safety_checker
:
StableDiffusionSafetyChecker
,
feature_extractor
:
CLIPFeatureExtractor
,
requires_safety_checker
:
bool
=
True
,
):
super
().
__init__
(
vae
=
vae
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
requires_safety_checker
=
requires_safety_checker
,
)
self
.
__init__additional__
()
else
:
def
__init__
(
self
,
vae
:
AutoencoderKL
,
text_encoder
:
CLIPTextModel
,
tokenizer
:
CLIPTokenizer
,
unet
:
UNet2DConditionModel
,
scheduler
:
SchedulerMixin
,
safety_checker
:
StableDiffusionSafetyChecker
,
feature_extractor
:
CLIPFeatureExtractor
,
):
super
().
__init__
(
vae
=
vae
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
)
self
.
__init__additional__
()
def
__init__additional__
(
self
):
if
not
hasattr
(
self
,
"vae_scale_factor"
):
setattr
(
self
,
"vae_scale_factor"
,
2
**
(
len
(
self
.
vae
.
config
.
block_out_channels
)
-
1
))
@
property
def
_execution_device
(
self
):
r
"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if
self
.
device
!=
torch
.
device
(
"meta"
)
or
not
hasattr
(
self
.
unet
,
"_hf_hook"
):
return
self
.
device
for
module
in
self
.
unet
.
modules
():
if
(
hasattr
(
module
,
"_hf_hook"
)
and
hasattr
(
module
.
_hf_hook
,
"execution_device"
)
and
module
.
_hf_hook
.
execution_device
is
not
None
):
return
torch
.
device
(
module
.
_hf_hook
.
execution_device
)
return
self
.
device
def
_encode_prompt
(
self
,
...
...
@@ -752,37 +823,33 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
extra_step_kwargs
=
self
.
prepare_extra_step_kwargs
(
generator
,
eta
)
# 8. Denoising loop
num_warmup_steps
=
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
with
self
.
progress_bar
(
total
=
num_inference_steps
)
as
progress_bar
:
for
i
,
t
in
enumerate
(
timesteps
):
# expand the latents if we are doing classifier free guidance
latent_model_input
=
torch
.
cat
([
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latent_model_input
,
t
)
# predict the noise residual
noise_pred
=
self
.
unet
(
latent_model_input
,
t
,
encoder_hidden_states
=
text_embeddings
).
sample
# perform guidance
if
do_classifier_free_guidance
:
noise_pred_uncond
,
noise_pred_text
=
noise_pred
.
chunk
(
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_pred_text
-
noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
if
mask
is
not
None
:
# masking
init_latents_proper
=
self
.
scheduler
.
add_noise
(
init_latents_orig
,
noise
,
torch
.
tensor
([
t
]))
latents
=
(
init_latents_proper
*
mask
)
+
(
latents
*
(
1
-
mask
))
# call the callback, if provided
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
i
%
callback_steps
==
0
:
if
callback
is
not
None
:
callback
(
i
,
t
,
latents
)
if
is_cancelled_callback
is
not
None
and
is_cancelled_callback
():
return
None
for
i
,
t
in
enumerate
(
self
.
progress_bar
(
timesteps
)):
# expand the latents if we are doing classifier free guidance
latent_model_input
=
torch
.
cat
([
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latent_model_input
,
t
)
# predict the noise residual
noise_pred
=
self
.
unet
(
latent_model_input
,
t
,
encoder_hidden_states
=
text_embeddings
).
sample
# perform guidance
if
do_classifier_free_guidance
:
noise_pred_uncond
,
noise_pred_text
=
noise_pred
.
chunk
(
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_pred_text
-
noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
if
mask
is
not
None
:
# masking
init_latents_proper
=
self
.
scheduler
.
add_noise
(
init_latents_orig
,
noise
,
torch
.
tensor
([
t
]))
latents
=
(
init_latents_proper
*
mask
)
+
(
latents
*
(
1
-
mask
))
# call the callback, if provided
if
i
%
callback_steps
==
0
:
if
callback
is
not
None
:
callback
(
i
,
t
,
latents
)
if
is_cancelled_callback
is
not
None
and
is_cancelled_callback
():
return
None
# 9. Post-processing
image
=
self
.
decode_latents
(
latents
)
...
...
examples/community/lpw_stable_diffusion_onnx.py
View file @
f242eba4
...
...
@@ -5,14 +5,55 @@ from typing import Callable, List, Optional, Union
import
numpy
as
np
import
torch
import
diffusers
import
PIL
from
diffusers
import
OnnxStableDiffusionPipeline
,
SchedulerMixin
from
diffusers.onnx_utils
import
ORT_TO_NP_TYPE
,
OnnxRuntimeModel
from
diffusers.onnx_utils
import
OnnxRuntimeModel
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionPipelineOutput
from
diffusers.utils
import
PIL_INTERPOLATION
,
deprecate
,
logging
from
diffusers.utils
import
deprecate
,
logging
from
packaging
import
version
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
try
:
from
diffusers.onnx_utils
import
ORT_TO_NP_TYPE
except
ImportError
:
ORT_TO_NP_TYPE
=
{
"tensor(bool)"
:
np
.
bool_
,
"tensor(int8)"
:
np
.
int8
,
"tensor(uint8)"
:
np
.
uint8
,
"tensor(int16)"
:
np
.
int16
,
"tensor(uint16)"
:
np
.
uint16
,
"tensor(int32)"
:
np
.
int32
,
"tensor(uint32)"
:
np
.
uint32
,
"tensor(int64)"
:
np
.
int64
,
"tensor(uint64)"
:
np
.
uint64
,
"tensor(float16)"
:
np
.
float16
,
"tensor(float)"
:
np
.
float32
,
"tensor(double)"
:
np
.
float64
,
}
try
:
from
diffusers.utils
import
PIL_INTERPOLATION
except
ImportError
:
if
version
.
parse
(
version
.
parse
(
PIL
.
__version__
).
base_version
)
>=
version
.
parse
(
"9.1.0"
):
PIL_INTERPOLATION
=
{
"linear"
:
PIL
.
Image
.
Resampling
.
BILINEAR
,
"bilinear"
:
PIL
.
Image
.
Resampling
.
BILINEAR
,
"bicubic"
:
PIL
.
Image
.
Resampling
.
BICUBIC
,
"lanczos"
:
PIL
.
Image
.
Resampling
.
LANCZOS
,
"nearest"
:
PIL
.
Image
.
Resampling
.
NEAREST
,
}
else
:
PIL_INTERPOLATION
=
{
"linear"
:
PIL
.
Image
.
LINEAR
,
"bilinear"
:
PIL
.
Image
.
BILINEAR
,
"bicubic"
:
PIL
.
Image
.
BICUBIC
,
"lanczos"
:
PIL
.
Image
.
LANCZOS
,
"nearest"
:
PIL
.
Image
.
NEAREST
,
}
# ------------------------------------------------------------------------------
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
re_attention
=
re
.
compile
(
...
...
@@ -390,30 +431,59 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
"""
if
version
.
parse
(
version
.
parse
(
diffusers
.
__version__
).
base_version
)
>=
version
.
parse
(
"0.9.0"
):
def
__init__
(
self
,
vae_encoder
:
OnnxRuntimeModel
,
vae_decoder
:
OnnxRuntimeModel
,
text_encoder
:
OnnxRuntimeModel
,
tokenizer
:
CLIPTokenizer
,
unet
:
OnnxRuntimeModel
,
scheduler
:
SchedulerMixin
,
safety_checker
:
OnnxRuntimeModel
,
feature_extractor
:
CLIPFeatureExtractor
,
requires_safety_checker
:
bool
=
True
,
):
super
().
__init__
(
vae_encoder
=
vae_encoder
,
vae_decoder
=
vae_decoder
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
requires_safety_checker
=
requires_safety_checker
,
)
self
.
__init__additional__
()
def
__init__
(
self
,
vae_encoder
:
OnnxRuntimeModel
,
vae_decoder
:
OnnxRuntimeModel
,
text_encoder
:
OnnxRuntimeModel
,
tokenizer
:
CLIPTokenizer
,
unet
:
OnnxRuntimeModel
,
scheduler
:
SchedulerMixin
,
safety_checker
:
OnnxRuntimeModel
,
feature_extractor
:
CLIPFeatureExtractor
,
requires_safety_checker
:
bool
=
True
,
):
super
().
__init__
(
vae_encoder
=
vae_encoder
,
vae_decoder
=
vae_decoder
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
requires_safety_checker
=
requires_safety_checker
,
)
else
:
def
__init__
(
self
,
vae_encoder
:
OnnxRuntimeModel
,
vae_decoder
:
OnnxRuntimeModel
,
text_encoder
:
OnnxRuntimeModel
,
tokenizer
:
CLIPTokenizer
,
unet
:
OnnxRuntimeModel
,
scheduler
:
SchedulerMixin
,
safety_checker
:
OnnxRuntimeModel
,
feature_extractor
:
CLIPFeatureExtractor
,
):
super
().
__init__
(
vae_encoder
=
vae_encoder
,
vae_decoder
=
vae_decoder
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
)
self
.
__init__additional__
()
def
__init__additional__
(
self
):
self
.
unet_in_channels
=
4
self
.
vae_scale_factor
=
8
...
...
@@ -741,49 +811,47 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
extra_step_kwargs
=
self
.
prepare_extra_step_kwargs
(
generator
,
eta
)
# 8. Denoising loop
num_warmup_steps
=
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
with
self
.
progress_bar
(
total
=
num_inference_steps
)
as
progress_bar
:
for
i
,
t
in
enumerate
(
timesteps
):
# expand the latents if we are doing classifier free guidance
latent_model_input
=
np
.
concatenate
([
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
torch
.
from_numpy
(
latent_model_input
),
t
)
latent_model_input
=
latent_model_input
.
numpy
()
# predict the noise residual
noise_pred
=
self
.
unet
(
sample
=
latent_model_input
,
timestep
=
np
.
array
([
t
],
dtype
=
timestep_dtype
),
encoder_hidden_states
=
text_embeddings
,
)
noise_pred
=
noise_pred
[
0
]
for
i
,
t
in
enumerate
(
self
.
progress_bar
(
timesteps
)):
# expand the latents if we are doing classifier free guidance
latent_model_input
=
np
.
concatenate
([
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
torch
.
from_numpy
(
latent_model_input
),
t
)
latent_model_input
=
latent_model_input
.
numpy
()
# predict the noise residual
noise_pred
=
self
.
unet
(
sample
=
latent_model_input
,
timestep
=
np
.
array
([
t
],
dtype
=
timestep_dtype
),
encoder_hidden_states
=
text_embeddings
,
)
noise_pred
=
noise_pred
[
0
]
# perform guidance
if
do_classifier_free_guidance
:
noise_pred_uncond
,
noise_pred_text
=
np
.
split
(
noise_pred
,
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_pred_text
-
noise_pred_uncond
)
# perform guidance
if
do_classifier_free_guidance
:
noise_pred_uncond
,
noise_pred_text
=
np
.
split
(
noise_pred
,
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_pred_text
-
noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output
=
self
.
scheduler
.
step
(
torch
.
from_numpy
(
noise_pred
),
t
,
torch
.
from_numpy
(
latents
),
**
extra_step_kwargs
)
latents
=
scheduler_output
.
prev_sample
.
numpy
()
if
mask
is
not
None
:
# masking
init_latents_proper
=
self
.
scheduler
.
add_noise
(
torch
.
from_numpy
(
init_latents_orig
),
torch
.
from_numpy
(
noise
),
t
,
).
numpy
()
latents
=
(
init_latents_proper
*
mask
)
+
(
latents
*
(
1
-
mask
))
# call the callback, if provided
if
i
%
callback_steps
==
0
:
if
callback
is
not
None
:
callback
(
i
,
t
,
latents
)
if
is_cancelled_callback
is
not
None
and
is_cancelled_callback
():
return
None
# compute the previous noisy sample x_t -> x_t-1
scheduler_output
=
self
.
scheduler
.
step
(
torch
.
from_numpy
(
noise_pred
),
t
,
torch
.
from_numpy
(
latents
),
**
extra_step_kwargs
)
latents
=
scheduler_output
.
prev_sample
.
numpy
()
if
mask
is
not
None
:
# masking
init_latents_proper
=
self
.
scheduler
.
add_noise
(
torch
.
from_numpy
(
init_latents_orig
),
torch
.
from_numpy
(
noise
),
t
,
).
numpy
()
latents
=
(
init_latents_proper
*
mask
)
+
(
latents
*
(
1
-
mask
))
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
i
%
callback_steps
==
0
:
if
callback
is
not
None
:
callback
(
i
,
t
,
latents
)
if
is_cancelled_callback
is
not
None
and
is_cancelled_callback
():
return
None
# 9. Post-processing
image
=
self
.
decode_latents
(
latents
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment