Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
f242eba4
"tests/vscode:/vscode.git/clone" did not exist on "5fbb33e73dd1b05426882829875e143068a84482"
Unverified
Commit
f242eba4
authored
Dec 09, 2022
by
SkyTNT
Committed by
GitHub
Dec 09, 2022
Browse files
Fix lpw stable diffusion pipeline compatibility (#1622)
parent
3faf204c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
254 additions
and
119 deletions
+254
-119
examples/community/lpw_stable_diffusion.py
examples/community/lpw_stable_diffusion.py
+120
-53
examples/community/lpw_stable_diffusion_onnx.py
examples/community/lpw_stable_diffusion_onnx.py
+134
-66
No files found.
examples/community/lpw_stable_diffusion.py
View file @
f242eba4
...
...
@@ -5,14 +5,37 @@ from typing import Callable, List, Optional, Union
import
numpy
as
np
import
torch
import
diffusers
import
PIL
from
diffusers
import
SchedulerMixin
,
StableDiffusionPipeline
from
diffusers.models
import
AutoencoderKL
,
UNet2DConditionModel
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionPipelineOutput
,
StableDiffusionSafetyChecker
from
diffusers.utils
import
PIL_INTERPOLATION
,
deprecate
,
logging
from
diffusers.utils
import
deprecate
,
logging
from
packaging
import
version
from
transformers
import
CLIPFeatureExtractor
,
CLIPTextModel
,
CLIPTokenizer
try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    # Older diffusers releases do not export PIL_INTERPOLATION; rebuild the
    # same name -> resampling-filter mapping locally.  Pillow >= 9.1.0 moved
    # the filter constants into the `PIL.Image.Resampling` enum, so choose
    # the attribute source (and the spelling of "linear") per version.
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        _filter_source = PIL.Image.Resampling
        _filter_names = {
            "linear": "BILINEAR",
            "bilinear": "BILINEAR",
            "bicubic": "BICUBIC",
            "lanczos": "LANCZOS",
            "nearest": "NEAREST",
        }
    else:
        _filter_source = PIL.Image
        _filter_names = {
            "linear": "LINEAR",
            "bilinear": "BILINEAR",
            "bicubic": "BICUBIC",
            "lanczos": "LANCZOS",
            "nearest": "NEAREST",
        }
    PIL_INTERPOLATION = {mode: getattr(_filter_source, name) for mode, name in _filter_names.items()}
# ------------------------------------------------------------------------------
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
re_attention
=
re
.
compile
(
...
...
@@ -404,6 +427,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
if
version
.
parse
(
version
.
parse
(
diffusers
.
__version__
).
base_version
)
>=
version
.
parse
(
"0.9.0"
):
def
__init__
(
self
,
vae
:
AutoencoderKL
,
...
...
@@ -425,6 +450,52 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
feature_extractor
=
feature_extractor
,
requires_safety_checker
=
requires_safety_checker
,
)
self
.
__init__additional__
()
else
:
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPFeatureExtractor,
):
    # Constructor variant selected when diffusers < 0.9.0: the base
    # StableDiffusionPipeline.__init__ of those releases does not accept
    # `requires_safety_checker`, so it is not forwarded here (the >= 0.9.0
    # branch above does forward it).
    super().__init__(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Backfill attributes the older base class does not define.
    self.__init__additional__()
def __init__additional__(self):
    """Backfill attributes missing from older diffusers base pipelines.

    diffusers < 0.9.0 does not set ``vae_scale_factor`` in
    ``StableDiffusionPipeline.__init__``; the rest of this pipeline relies
    on it, so derive it from the VAE config when absent.
    """
    if not hasattr(self, "vae_scale_factor"):
        # Plain attribute assignment instead of setattr() with a constant
        # string name — same effect, clearer, and visible to static tools.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@property
def _execution_device(self):
    r"""
    Device on which the pipeline's models are executed.

    After calling `pipeline.enable_sequential_cpu_offload()` the pipeline
    itself sits on the "meta" device, and the real execution device can
    only be read from Accelerate's module hooks — so scan the UNet's
    submodules for a hook that carries one.
    """
    offloaded = self.device == torch.device("meta") and hasattr(self.unet, "_hf_hook")
    if not offloaded:
        return self.device
    for submodule in self.unet.modules():
        hook = getattr(submodule, "_hf_hook", None)
        exec_device = getattr(hook, "execution_device", None)
        if exec_device is not None:
            return torch.device(exec_device)
    return self.device
def
_encode_prompt
(
self
,
...
...
@@ -752,9 +823,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
extra_step_kwargs
=
self
.
prepare_extra_step_kwargs
(
generator
,
eta
)
# 8. Denoising loop
num_warmup_steps
=
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
with
self
.
progress_bar
(
total
=
num_inference_steps
)
as
progress_bar
:
for
i
,
t
in
enumerate
(
timesteps
):
for
i
,
t
in
enumerate
(
self
.
progress_bar
(
timesteps
)):
# expand the latents if we are doing classifier free guidance
latent_model_input
=
torch
.
cat
([
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latent_model_input
,
t
)
...
...
@@ -776,8 +845,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
latents
=
(
init_latents_proper
*
mask
)
+
(
latents
*
(
1
-
mask
))
# call the callback, if provided
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
i
%
callback_steps
==
0
:
if
callback
is
not
None
:
callback
(
i
,
t
,
latents
)
...
...
examples/community/lpw_stable_diffusion_onnx.py
View file @
f242eba4
...
...
@@ -5,14 +5,55 @@ from typing import Callable, List, Optional, Union
import
numpy
as
np
import
torch
import
diffusers
import
PIL
from
diffusers
import
OnnxStableDiffusionPipeline
,
SchedulerMixin
from
diffusers.onnx_utils
import
ORT_TO_NP_TYPE
,
OnnxRuntimeModel
from
diffusers.onnx_utils
import
OnnxRuntimeModel
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionPipelineOutput
from
diffusers.utils
import
PIL_INTERPOLATION
,
deprecate
,
logging
from
diffusers.utils
import
deprecate
,
logging
from
packaging
import
version
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
try:
    from diffusers.onnx_utils import ORT_TO_NP_TYPE
except ImportError:
    # Older diffusers releases do not export ORT_TO_NP_TYPE; reconstruct the
    # ONNX-Runtime-tensor-type-string -> numpy-dtype mapping locally.
    _ORT_NP_PAIRS = [
        ("bool", np.bool_),
        ("int8", np.int8),
        ("uint8", np.uint8),
        ("int16", np.int16),
        ("uint16", np.uint16),
        ("int32", np.int32),
        ("uint32", np.uint32),
        ("int64", np.int64),
        ("uint64", np.uint64),
        ("float16", np.float16),
        ("float", np.float32),
        ("double", np.float64),
    ]
    ORT_TO_NP_TYPE = {"tensor(%s)" % ort_name: np_dtype for ort_name, np_dtype in _ORT_NP_PAIRS}
try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    # Older diffusers releases do not export PIL_INTERPOLATION; rebuild the
    # same name -> resampling-filter mapping here.  Pillow >= 9.1.0 moved
    # the filter constants into the `PIL.Image.Resampling` enum (and spells
    # "linear" as BILINEAR), so pick the attribute source per version.
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        _resampling_src = PIL.Image.Resampling
        _resampling_names = {
            "linear": "BILINEAR",
            "bilinear": "BILINEAR",
            "bicubic": "BICUBIC",
            "lanczos": "LANCZOS",
            "nearest": "NEAREST",
        }
    else:
        _resampling_src = PIL.Image
        _resampling_names = {
            "linear": "LINEAR",
            "bilinear": "BILINEAR",
            "bicubic": "BICUBIC",
            "lanczos": "LANCZOS",
            "nearest": "NEAREST",
        }
    PIL_INTERPOLATION = {mode: getattr(_resampling_src, name) for mode, name in _resampling_names.items()}
# ------------------------------------------------------------------------------
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
re_attention
=
re
.
compile
(
...
...
@@ -390,6 +431,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
"""
if
version
.
parse
(
version
.
parse
(
diffusers
.
__version__
).
base_version
)
>=
version
.
parse
(
"0.9.0"
):
def
__init__
(
self
,
...
...
@@ -414,6 +456,34 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
feature_extractor
=
feature_extractor
,
requires_safety_checker
=
requires_safety_checker
,
)
self
.
__init__additional__
()
else
:
def __init__(
    self,
    vae_encoder: OnnxRuntimeModel,
    vae_decoder: OnnxRuntimeModel,
    text_encoder: OnnxRuntimeModel,
    tokenizer: CLIPTokenizer,
    unet: OnnxRuntimeModel,
    scheduler: SchedulerMixin,
    safety_checker: OnnxRuntimeModel,
    feature_extractor: CLIPFeatureExtractor,
):
    # Constructor variant selected when diffusers < 0.9.0: the base
    # OnnxStableDiffusionPipeline.__init__ of those releases does not accept
    # `requires_safety_checker`, so it is not forwarded here (the >= 0.9.0
    # branch above does forward it).
    super().__init__(
        vae_encoder=vae_encoder,
        vae_decoder=vae_decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Backfill attributes the older base class does not define.
    self.__init__additional__()
def __init__additional__(self):
    # Compatibility shim: set attributes that older diffusers base
    # pipelines do not define.  Values are hard-coded because the ONNX
    # models do not expose the configs they were exported from.
    # NOTE(review): 4 latent channels and an 8x VAE scale factor match
    # standard Stable Diffusion checkpoints — confirm for other exports.
    self.unet_in_channels = 4
    self.vae_scale_factor = 8
...
...
@@ -741,9 +811,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
extra_step_kwargs
=
self
.
prepare_extra_step_kwargs
(
generator
,
eta
)
# 8. Denoising loop
num_warmup_steps
=
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
with
self
.
progress_bar
(
total
=
num_inference_steps
)
as
progress_bar
:
for
i
,
t
in
enumerate
(
timesteps
):
for
i
,
t
in
enumerate
(
self
.
progress_bar
(
timesteps
)):
# expand the latents if we are doing classifier free guidance
latent_model_input
=
np
.
concatenate
([
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
torch
.
from_numpy
(
latent_model_input
),
t
)
...
...
@@ -777,13 +845,13 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
).
numpy
()
latents
=
(
init_latents_proper
*
mask
)
+
(
latents
*
(
1
-
mask
))
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
# call the callback, if provided
if
i
%
callback_steps
==
0
:
if
callback
is
not
None
:
callback
(
i
,
t
,
latents
)
if
is_cancelled_callback
is
not
None
and
is_cancelled_callback
():
return
None
# 9. Post-processing
image
=
self
.
decode_latents
(
latents
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment