renzhc / diffusers_dcu · commit 4eb9ad0d (unverified)
Authored Dec 07, 2022 by SkyTNT; committed by GitHub on Dec 07, 2022
Parent: 896c98a2

[Community Pipeline] fix lpw_stable_diffusion (#1570)

* fix lpw_stable_diffusion
* rollback preprocess_mask resample
Showing 2 changed files with 563 additions and 537 deletions:

    examples/community/lpw_stable_diffusion.py       +275 −313
    examples/community/lpw_stable_diffusion_onnx.py  +288 −224
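Context for readers: lpw_stable_diffusion is the "long prompt weighting" community pipeline. A minimal usage sketch, assuming the standard `custom_pipeline` loading mechanism for diffusers community pipelines (the model ID and prompt here are illustrative):

    # Minimal sketch: load the community pipeline this commit fixes.
    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",        # illustrative model ID
        custom_pipeline="lpw_stable_diffusion",  # fetches this community file
        torch_dtype=torch.float16,
    ).to("cuda")

    # Parenthesized weighting syntax; prompts may exceed the 77-token CLIP limit.
    image = pipe.text2img("a (very detailed:1.2) painting of a [plain] castle").images[0]
    image.save("castle.png")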
examples/community/lpw_stable_diffusion.py
@@ -6,38 +6,13 @@ import numpy as np
 import torch

 import PIL
-from diffusers.configuration_utils import FrozenDict
+from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, is_accelerate_available, logging
-
-# TODO: remove and import from diffusers.utils when the new version of diffusers is released
-from packaging import version
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

-if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.Resampling.BILINEAR,
-        "bilinear": PIL.Image.Resampling.BILINEAR,
-        "bicubic": PIL.Image.Resampling.BICUBIC,
-        "lanczos": PIL.Image.Resampling.LANCZOS,
-        "nearest": PIL.Image.Resampling.NEAREST,
-    }
-else:
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.LINEAR,
-        "bilinear": PIL.Image.BILINEAR,
-        "bicubic": PIL.Image.BICUBIC,
-        "lanczos": PIL.Image.LANCZOS,
-        "nearest": PIL.Image.NEAREST,
-    }
-# ------------------------------------------------------------------------------

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 re_attention = re.compile(
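The removed block was a local Pillow compatibility shim: Pillow 9.1.0 moved the resampling constants into the `PIL.Image.Resampling` enum, and the fix imports the ready-made `PIL_INTERPOLATION` mapping from `diffusers.utils` instead of maintaining the shim here. A tiny sketch of the lookup the pipeline now relies on (assuming a diffusers release that exports the mapping):

    # Sketch: the same resize call the pipeline makes, via the shared mapping.
    import PIL.Image
    from diffusers.utils import PIL_INTERPOLATION  # assumes a recent diffusers release

    img = PIL.Image.new("RGB", (512, 512))
    latent_sized = img.resize((64, 64), resample=PIL_INTERPOLATION["nearest"])
    print(latent_sized.size)  # (64, 64)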
@@ -146,7 +121,7 @@ def parse_prompt_attention(text):
     return res


-def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
+def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
     r"""
     Tokenize a list of prompts and return its tokens with weights of each token.
@@ -207,7 +182,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline,
+    pipe: StableDiffusionPipeline,
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
@@ -247,10 +222,10 @@ def get_unweighted_text_embeddings(
 def get_weighted_text_embeddings(
-    pipe: DiffusionPipeline,
+    pipe: StableDiffusionPipeline,
     prompt: Union[str, List[str]],
     uncond_prompt: Optional[Union[str, List[str]]] = None,
-    max_embeddings_multiples: Optional[int] = 1,
+    max_embeddings_multiples: Optional[int] = 3,
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
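Raising the default `max_embeddings_multiples` from 1 to 3 means prompts can span up to three 77-token CLIP windows without any extra argument. A back-of-the-envelope sketch of the resulting capacity (assuming BOS/EOS are reserved in each window, which is how the chunking here counts):

    # Sketch: usable content tokens for a given max_embeddings_multiples.
    model_max_length = 77  # CLIP text encoder window
    for multiples in (1, 3):
        usable = (model_max_length - 2) * multiples  # minus BOS/EOS per window
        print(f"{multiples} window(s): {usable} content tokens")
    # 1 window(s): 75 content tokens
    # 3 window(s): 225 content tokens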
@@ -264,14 +239,14 @@ def get_weighted_text_embeddings(
     Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.

     Args:
-        pipe (`DiffusionPipeline`):
+        pipe (`StableDiffusionPipeline`):
             Pipe to provide access to the tokenizer and the text encoder.
         prompt (`str` or `List[str]`):
             The prompt or prompts to guide the image generation.
         uncond_prompt (`str` or `List[str]`):
             The unconditional prompt or prompts for guide the image generation. If unconditional prompt
             is provided, the embeddings of prompt and uncond_prompt are concatenated.
-        max_embeddings_multiples (`int`, *optional*, defaults to `1`):
+        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
             The max multiple length of prompt embeddings compared to the max output length of text encoder.
         no_boseos_middle (`bool`, *optional*, defaults to `False`):
             If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
@@ -387,11 +362,11 @@ def preprocess_image(image):
     return 2.0 * image - 1.0


-def preprocess_mask(mask):
+def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
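The new `scale_factor` parameter lets callers pass the pipeline's real `vae_scale_factor` instead of the hard-coded 8, while the commit message notes the resample mode itself was rolled back to `nearest`. A shape walk-through sketch for a 512x512 mask with the default factor:

    # Sketch: what preprocess_mask produces for a 512x512 mask, scale_factor=8.
    import numpy as np
    from PIL import Image

    mask = Image.new("L", (512, 512))
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))             # still (512, 512)
    mask = mask.resize((w // 8, h // 8), Image.NEAREST)  # (64, 64), the latent grid
    arr = np.array(mask).astype(np.float32) / 255.0      # (64, 64)
    arr = np.tile(arr, (4, 1, 1))                        # (4, 64, 64), one per latent channel
    arr = arr[None]                                      # (1, 4, 64, 64), batch axis added
    print(arr.shape)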
@@ -400,7 +375,7 @@ def preprocess_mask(mask):
     return mask


-class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
+class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
     weighting in prompt.
@@ -435,50 +410,12 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        requires_safety_checker: bool = True,
     ):
-        super().__init__()
-
-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-            deprecation_message = (
-                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
-                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
-                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
-                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
-                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
-                " file"
-            )
-            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(scheduler.config)
-            new_config["steps_offset"] = 1
-            scheduler._internal_dict = FrozenDict(new_config)
-
-        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
-            deprecation_message = (
-                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
-                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
-                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
-                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
-                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
-            )
-            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(scheduler.config)
-            new_config["clip_sample"] = False
-            scheduler._internal_dict = FrozenDict(new_config)
-
-        if safety_checker is None:
-            logger.warning(
-                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
-                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
-                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
-                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
-                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
-                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
-            )
-
-        self.register_modules(
+        super().__init__(
             vae=vae,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
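Widening the `scheduler` annotation from a three-way `Union` to `SchedulerMixin` (together with inheriting from `StableDiffusionPipeline`, which now owns the deprecation and safety-checker warnings) means any diffusers scheduler can be swapped in. A hedged sketch, reusing the `pipe` from the loading example above:

    # Sketch: swap in a scheduler that the old Union annotation did not cover.
    from diffusers import EulerDiscreteScheduler

    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)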
@@ -486,51 +423,171 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            requires_safety_checker=requires_safety_checker,
         )

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
-        self.unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
-    def enable_sequential_cpu_offload(self):
-        r"""
-        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
-        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
-        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
-        """
-        if is_accelerate_available():
-            from accelerate import cpu_offload
-        else:
-            raise ImportError("Please install accelerate via `pip install accelerate`")
-
-        device = self.device
-
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+        max_embeddings_multiples,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )
+
+        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+            pipe=self,
+            prompt=prompt,
+            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+            max_embeddings_multiples=max_embeddings_multiples,
+        )
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            bs_embed, seq_len, _ = uncond_embeddings.shape
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def check_inputs(self, prompt, height, width, strength, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+    def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
+        if is_text2img:
+            return self.scheduler.timesteps.to(device), num_inference_steps
+        else:
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:].to(device)
+            return timesteps, num_inference_steps - t_start
+
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        else:
+            has_nsfw_concept = None
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
+        if image is None:
+            shape = (
+                batch_size,
+                self.unet.in_channels,
+                height // self.vae_scale_factor,
+                width // self.vae_scale_factor,
+            )
+
+            if latents is None:
+                if device.type == "mps":
+                    # randn does not work reproducibly on mps
+                    latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+                else:
+                    latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            else:
+                if latents.shape != shape:
+                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+                latents = latents.to(device)
+
+            # scale the initial noise by the standard deviation required by the scheduler
+            latents = latents * self.scheduler.init_noise_sigma
+            return latents, None, None
+        else:
+            init_latent_dist = self.vae.encode(image).latent_dist
+            init_latents = init_latent_dist.sample(generator=generator)
+            init_latents = 0.18215 * init_latents
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
+            init_latents_orig = init_latents
+            shape = init_latents.shape
+
+            # add noise to latents using the timesteps
+            if device.type == "mps":
+                noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+            else:
+                noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = self.scheduler.add_noise(init_latents, noise, timestep)
+            return latents, init_latents_orig, noise

     @torch.no_grad()
     def __call__(
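Among the helpers factored out above, `get_timesteps` carries the img2img/inpaint scheduling rule: `strength` picks how much of the schedule actually runs. The arithmetic in isolation (pure Python; `steps_offset` assumed to be 0):

    # Sketch: strength selects a suffix of the timestep schedule.
    num_inference_steps = 50
    strength = 0.8
    offset = 0  # self.scheduler.config.get("steps_offset", 0)

    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    print(init_timestep, t_start)  # 40 10 -> the last 40 of 50 steps are executed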
@@ -634,221 +691,111 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
             image = init_image or image

-        if isinstance(prompt, str):
-            batch_size = 1
-            prompt = [prompt]
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
-        # get prompt text embeddings
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-        # get unconditional embeddings for classifier free guidance
-        if negative_prompt is None:
-            negative_prompt = [""] * batch_size
-        elif isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-        if batch_size != len(negative_prompt):
-            raise ValueError(
-                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                " the batch size of `prompt`."
-            )
-
-        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
-            pipe=self,
-            prompt=prompt,
-            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
-            max_embeddings_multiples=max_embeddings_multiples,
-            **kwargs,
-        )
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        if do_classifier_free_guidance:
-            bs_embed, seq_len, _ = uncond_embeddings.shape
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        latents_dtype = text_embeddings.dtype
-        init_latents_orig = None
-        mask = None
-        noise = None
-
-        if image is None:
-            # get the initial random noise unless the user supplied it
-
-            # Unlike in other pipelines, latents need to be generated in the target device
-            # for 1-to-1 results reproducibility with the CompVis implementation.
-            # However this currently doesn't work in `mps`.
-            latents_shape = (
-                batch_size * num_images_per_prompt,
-                self.unet.in_channels,
-                height // 8,
-                width // 8,
-            )
-
-            if latents is None:
-                if self.device.type == "mps":
-                    # randn does not exist on mps
-                    latents = torch.randn(
-                        latents_shape,
-                        generator=generator,
-                        device="cpu",
-                        dtype=latents_dtype,
-                    ).to(self.device)
-                else:
-                    latents = torch.randn(
-                        latents_shape,
-                        generator=generator,
-                        device=self.device,
-                        dtype=latents_dtype,
-                    )
-            else:
-                if latents.shape != latents_shape:
-                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-                latents = latents.to(self.device)
-
-            timesteps = self.scheduler.timesteps.to(self.device)
-
-            # scale the initial noise by the standard deviation required by the scheduler
-            latents = latents * self.scheduler.init_noise_sigma
-        else:
-            if isinstance(image, PIL.Image.Image):
-                image = preprocess_image(image)
-            # encode the init image into latents and scale the latents
-            image = image.to(device=self.device, dtype=latents_dtype)
-            init_latent_dist = self.vae.encode(image).latent_dist
-            init_latents = init_latent_dist.sample(generator=generator)
-            init_latents = 0.18215 * init_latents
-            init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-            init_latents_orig = init_latents
-
-            # preprocess mask
-            if mask_image is not None:
-                if isinstance(mask_image, PIL.Image.Image):
-                    mask_image = preprocess_mask(mask_image)
-                mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
-                mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-                # check sizes
-                if not mask.shape == init_latents.shape:
-                    raise ValueError("The mask and image should be the same size!")
-
-            # get the original timestep using init_timestep
-            offset = self.scheduler.config.get("steps_offset", 0)
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
-
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-            # add noise to latents using the timesteps
-            if self.device.type == "mps":
-                # randn does not exist on mps
-                noise = torch.randn(
-                    init_latents.shape,
-                    generator=generator,
-                    device="cpu",
-                    dtype=latents_dtype,
-                ).to(self.device)
-            else:
-                noise = torch.randn(
-                    init_latents.shape,
-                    generator=generator,
-                    device=self.device,
-                    dtype=latents_dtype,
-                )
-            latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
-            t_start = max(num_inference_steps - init_timestep + offset, 0)
-            timesteps = self.scheduler.timesteps[t_start:].to(self.device)
-
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            if mask is not None:
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-                latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-            # call the callback, if provided
-            if i % callback_steps == 0:
-                if callback is not None:
-                    callback(i, t, latents)
-                if is_cancelled_callback is not None and is_cancelled_callback():
-                    return None
-
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
-
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
-            image, has_nsfw_concept = self.safety_checker(
-                images=image,
-                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
-            )
-        else:
-            has_nsfw_concept = None
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, strength, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            max_embeddings_multiples,
+        )
+        dtype = text_embeddings.dtype
+
+        # 4. Preprocess image and mask
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess_image(image)
+        if image is not None:
+            image = image.to(device=self.device, dtype=dtype)
+        if isinstance(mask_image, PIL.Image.Image):
+            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        if mask_image is not None:
+            mask = mask_image.to(device=self.device, dtype=dtype)
+            mask = torch.cat([mask] * batch_size * num_images_per_prompt)
+        else:
+            mask = None
+
+        # 5. set timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents, init_latents_orig, noise = self.prepare_latents(
+            image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                if mask is not None:
+                    # masking
+                    init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if i % callback_steps == 0:
+                        if callback is not None:
+                            callback(i, t, latents)
+                        if is_cancelled_callback is not None and is_cancelled_callback():
+                            return None
+
+        # 9. Post-processing
+        image = self.decode_latents(latents)
+
+        # 10. Run safety checker
+        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+        # 11. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)

         if not return_dict:
-            return (image, has_nsfw_concept)
+            return image, has_nsfw_concept

         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
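The denoising loop itself is unchanged in substance; the guidance step is still the standard classifier-free extrapolation from the unconditional toward the text-conditioned prediction. A toy sketch of that single line with scalar stand-ins:

    # Toy sketch of the guidance line, with scalars standing in for noise tensors.
    import torch

    noise_pred_uncond = torch.tensor([0.2])
    noise_pred_text = torch.tensor([0.5])
    guidance_scale = 7.5

    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(noise_pred)  # tensor([2.4500]), pushed well past the text prediction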
@@ -868,6 +815,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -915,6 +863,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -940,6 +891,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             output_type=output_type,
             return_dict=return_dict,
             callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             **kwargs,
         )
@@ -959,6 +911,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -1007,6 +960,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1031,6 +987,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             output_type=output_type,
             return_dict=return_dict,
             callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             **kwargs,
         )
@@ -1051,6 +1008,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -1103,6 +1061,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1128,6 +1089,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             output_type=output_type,
             return_dict=return_dict,
             callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             **kwargs,
         )
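These nine hunks thread the new `is_cancelled_callback` through `text2img`, `img2img`, and `inpaint`; `__call__` already checked it and returns `None` on cancellation. A hedged usage sketch (the `threading.Event` flag and the `pipe` variable are illustrative):

    # Sketch: cooperative cancellation, e.g. from a UI thread.
    import threading

    cancel_event = threading.Event()

    result = pipe.text2img(
        "a castle on a hill",
        is_cancelled_callback=cancel_event.is_set,  # polled every callback_steps steps
    )
    if result is None:
        print("generation was cancelled")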
examples/community/lpw_stable_diffusion_onnx.py
@@ -6,35 +6,13 @@ import numpy as np
 import torch

 import PIL
-from diffusers.onnx_utils import OnnxRuntimeModel
-from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
+from diffusers.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
-
-# TODO: remove and import from diffusers.utils when the new version of diffusers is released
-from packaging import version
+from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 from transformers import CLIPFeatureExtractor, CLIPTokenizer

-if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.Resampling.BILINEAR,
-        "bilinear": PIL.Image.Resampling.BILINEAR,
-        "bicubic": PIL.Image.Resampling.BICUBIC,
-        "lanczos": PIL.Image.Resampling.LANCZOS,
-        "nearest": PIL.Image.Resampling.NEAREST,
-    }
-else:
-    PIL_INTERPOLATION = {
-        "linear": PIL.Image.LINEAR,
-        "bilinear": PIL.Image.BILINEAR,
-        "bicubic": PIL.Image.BICUBIC,
-        "lanczos": PIL.Image.LANCZOS,
-        "nearest": PIL.Image.NEAREST,
-    }
-# ------------------------------------------------------------------------------

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 re_attention = re.compile(
@@ -262,7 +240,7 @@ def get_weighted_text_embeddings(
     Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.

     Args:
-        pipe (`DiffusionPipeline`):
+        pipe (`OnnxStableDiffusionPipeline`):
             Pipe to provide access to the tokenizer and the text encoder.
         prompt (`str` or `List[str]`):
             The prompt or prompts to guide the image generation.
@@ -392,11 +370,11 @@ def preprocess_image(image):
     return 2.0 * image - 1.0


-def preprocess_mask(mask):
+def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
@@ -404,7 +382,7 @@ def preprocess_mask(mask):
     return mask


-class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
+class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
     weighting in prompt.
@@ -420,12 +398,12 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         text_encoder: OnnxRuntimeModel,
         tokenizer: CLIPTokenizer,
         unet: OnnxRuntimeModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: SchedulerMixin,
         safety_checker: OnnxRuntimeModel,
         feature_extractor: CLIPFeatureExtractor,
+        requires_safety_checker: bool = True,
     ):
-        super().__init__()
-        self.register_modules(
+        super().__init__(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
             text_encoder=text_encoder,
@@ -434,8 +412,171 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            requires_safety_checker=requires_safety_checker,
         )
+        self.unet_in_channels = 4
+        self.vae_scale_factor = 8
+
+    def _encode_prompt(
+        self,
+        prompt,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+        max_embeddings_multiples,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )
+
+        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+            pipe=self,
+            prompt=prompt,
+            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+            max_embeddings_multiples=max_embeddings_multiples,
+        )
+
+        text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
+        if do_classifier_free_guidance:
+            uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
+            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def check_inputs(self, prompt, height, width, strength, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+    def get_timesteps(self, num_inference_steps, strength, is_text2img):
+        if is_text2img:
+            return self.scheduler.timesteps, num_inference_steps
+        else:
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:]
+            return timesteps, num_inference_steps - t_start
+
+    def run_safety_checker(self, image):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # There will throw an error if use safety_checker directly and batchsize>1
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
+        else:
+            has_nsfw_concept = None
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        latents = 1 / 0.18215 * latents
+        # image = self.vae_decoder(latent_sample=latents)[0]
+        # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )
+        image = np.clip(image / 2 + 0.5, 0, 1)
+        image = image.transpose((0, 2, 3, 1))
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
+        if image is None:
+            shape = (
+                batch_size,
+                self.unet_in_channels,
+                height // self.vae_scale_factor,
+                width // self.vae_scale_factor,
+            )
+
+            if latents is None:
+                latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
+            else:
+                if latents.shape != shape:
+                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+
+            # scale the initial noise by the standard deviation required by the scheduler
+            latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
+            return latents, None, None
+        else:
+            init_latents = self.vae_encoder(sample=image)[0]
+            init_latents = 0.18215 * init_latents
+            init_latents = np.concatenate([init_latents] * batch_size, axis=0)
+            init_latents_orig = init_latents
+            shape = init_latents.shape
+
+            # add noise to latents using the timesteps
+            noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
+            latents = self.scheduler.add_noise(
+                torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
+            ).numpy()
+            return latents, init_latents_orig, noise

     @torch.no_grad()
     def __call__(
         self,
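Two of the new ONNX helpers (`run_safety_checker`, `decode_latents`) dodge batch-size problems by looping over the batch axis one item at a time, as their in-code comments note. The slicing pattern in isolation, with a stand-in for the ONNX session:

    # Sketch: drive a batch-size-1 model over a larger batch via [i : i + 1] slices.
    import numpy as np

    def single_item_model(x):  # stand-in for an ONNX session call
        assert x.shape[0] == 1
        return x * 2.0

    batch = np.ones((4, 3), dtype=np.float32)
    out = np.concatenate([single_item_model(batch[i : i + 1]) for i in range(batch.shape[0])])
    print(out.shape)  # (4, 3)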
@@ -450,7 +591,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         strength: float = 0.8,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[np.random.RandomState] = None,
+        generator: Optional[torch.Generator] = None,
         latents: Optional[np.ndarray] = None,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
@@ -501,8 +642,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`np.random.RandomState`, *optional*):
-                A np.random.RandomState to make generation deterministic.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
             latents (`np.ndarray`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -537,204 +679,123 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
...
@@ -537,204 +679,123 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
init_image
=
deprecate
(
"init_image"
,
"0.12.0"
,
message
,
take_from
=
kwargs
)
init_image
=
deprecate
(
"init_image"
,
"0.12.0"
,
message
,
take_from
=
kwargs
)
image
=
init_image
or
image
image
=
init_image
or
image
if
isinstance
(
prompt
,
str
):
# 0. Default height and width to unet
batch_size
=
1
height
=
height
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
prompt
=
[
prompt
]
width
=
width
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
elif
isinstance
(
prompt
,
list
):
batch_size
=
len
(
prompt
)
else
:
raise
ValueError
(
f
"`prompt` has to be of type `str` or `list` but is
{
type
(
prompt
)
}
"
)
if
strength
<
0
or
strength
>
1
:
raise
ValueError
(
f
"The value of strength should in [0.0, 1.0] but is
{
strength
}
"
)
if
height
%
8
!=
0
or
width
%
8
!=
0
:
# 1. Check inputs. Raise error if not correct
raise
ValueError
(
f
"`height` and `width` have to be divisible by 8 but are
{
height
}
and
{
width
}
."
)
self
.
check_inputs
(
prompt
,
height
,
width
,
strength
,
callback_steps
)
if
(
callback_steps
is
None
)
or
(
callback_steps
is
not
None
and
(
not
isinstance
(
callback_steps
,
int
)
or
callback_steps
<=
0
)
):
raise
ValueError
(
f
"`callback_steps` has to be a positive integer but is
{
callback_steps
}
of type"
f
"
{
type
(
callback_steps
)
}
."
)
# get prompt text embeddings
# 2. Define call parameters
batch_size
=
1
if
isinstance
(
prompt
,
str
)
else
len
(
prompt
)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance
=
guidance_scale
>
1.0
do_classifier_free_guidance
=
guidance_scale
>
1.0
-        # get unconditional embeddings for classifier free guidance
-        if negative_prompt is None:
-            negative_prompt = [""] * batch_size
-        elif isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-        if batch_size != len(negative_prompt):
-            raise ValueError(
-                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                " the batch size of `prompt`."
-            )
-        if generator is None:
-            generator = np.random
-        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
-            pipe=self,
-            prompt=prompt,
-            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
-            max_embeddings_multiples=max_embeddings_multiples,
-            **kwargs,
-        )
-        text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
-        if do_classifier_free_guidance:
-            uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
-            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+        # 3. Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            max_embeddings_multiples,
+        )
+        dtype = text_embeddings.dtype
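A hedged reconstruction of what the new `self._encode_prompt` helper presumably consolidates, pieced together from the inline code removed above rather than copied from the helper itself. `get_weighted_text_embeddings` is the module-level helper defined earlier in this file; the standalone function name is illustrative.

import numpy as np


def encode_prompt_sketch(pipe, prompt, num_images_per_prompt, do_classifier_free_guidance,
                         negative_prompt, max_embeddings_multiples):
    if negative_prompt is None:
        negative_prompt = [""] * (1 if isinstance(prompt, str) else len(prompt))
    text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
        pipe=pipe,
        prompt=prompt,
        uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
        max_embeddings_multiples=max_embeddings_multiples,
    )
    # duplicate rows so each requested image per prompt gets its own embedding
    text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
    if do_classifier_free_guidance:
        uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
        # unconditional first: np.split(noise_pred, 2) in the denoising loop relies on this order
        text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
    return text_embeddings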
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-        latents_dtype = text_embeddings.dtype
-        init_latents_orig = None
-        mask = None
-        noise = None
-        if image is None:
-            latents_shape = (
-                batch_size * num_images_per_prompt,
-                4,
-                height // 8,
-                width // 8,
-            )
-            if latents is None:
-                latents = generator.randn(*latents_shape).astype(latents_dtype)
-            elif latents.shape != latents_shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            timesteps = self.scheduler.timesteps.to(self.device)
-            # scale the initial noise by the standard deviation required by the scheduler
-            latents = latents * self.scheduler.init_noise_sigma
-        else:
-            if isinstance(image, PIL.Image.Image):
-                image = preprocess_image(image)
-            # encode the init image into latents and scale the latents
-            image = image.astype(latents_dtype)
-            init_latents = self.vae_encoder(sample=image)[0]
-            init_latents = 0.18215 * init_latents
-            init_latents = np.concatenate([init_latents] * batch_size * num_images_per_prompt)
-            init_latents_orig = init_latents
-            # preprocess mask
-            if mask_image is not None:
-                if isinstance(mask_image, PIL.Image.Image):
-                    mask_image = preprocess_mask(mask_image)
-                mask_image = mask_image.astype(latents_dtype)
-                mask = np.concatenate([mask_image] * batch_size * num_images_per_prompt)
-                # check sizes
-                if not mask.shape == init_latents.shape:
-                    print(mask.shape, init_latents.shape)
-                    raise ValueError("The mask and image should be the same size!")
-            # get the original timestep using init_timestep
-            offset = self.scheduler.config.get("steps_offset", 0)
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt)
-            # add noise to latents using the timesteps
-            noise = generator.randn(*init_latents.shape).astype(latents_dtype)
-            latents = self.scheduler.add_noise(
-                torch.from_numpy(init_latents),
-                torch.from_numpy(noise),
-                timesteps,
-            ).numpy()
-            t_start = max(num_inference_steps - init_timestep + offset, 0)
-            timesteps = self.scheduler.timesteps[t_start:]
+        # 4. Preprocess image and mask
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess_image(image)
+        if image is not None:
+            image = image.astype(dtype)
+        if isinstance(mask_image, PIL.Image.Image):
+            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        if mask_image is not None:
+            mask = mask_image.astype(dtype)
+            mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
+        else:
+            mask = None
+        # 5. set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # 6. Prepare latent variables
+        latents, init_latents_orig, noise = self.prepare_latents(
+            image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            dtype,
+            generator,
+            latents,
+        )
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
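A hedged sketch of what the new `prepare_latents` helper presumably covers: the two removed branches above, pure-noise latents for text2img and VAE-encoded, re-noised latents for img2img/inpaint. The signature, seeding via `torch.randn(..., generator=...)`, and the return convention are assumptions, not taken from the helper itself.

import numpy as np
import torch


def prepare_latents_sketch(pipe, image, timestep, batch_size, height, width, dtype, generator, latents=None):
    if image is None:
        # text2img: start from pure noise, scaled by the scheduler's initial sigma
        shape = (batch_size, 4, height // 8, width // 8)
        if latents is None:
            latents = torch.randn(shape, generator=generator).numpy().astype(dtype)
        latents = latents * float(pipe.scheduler.init_noise_sigma)
        return latents, None, None
    # img2img / inpaint: encode the init image, then noise it up to `timestep`
    init_latents = pipe.vae_encoder(sample=image)[0] * 0.18215
    init_latents = np.concatenate([init_latents] * batch_size)
    noise = torch.randn(init_latents.shape, generator=generator).numpy().astype(dtype)
    latents = pipe.scheduler.add_noise(
        torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
    ).numpy()
    return latents, init_latents, noise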
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-            # predict the noise residual
-            noise_pred = self.unet(
-                sample=latent_model_input,
-                timestep=np.array([t]),
-                encoder_hidden_states=text_embeddings,
-            )
-            noise_pred = noise_pred[0]
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.numpy()
-            if mask is not None:
-                # masking
-                init_latents_proper = self.scheduler.add_noise(
-                    torch.from_numpy(init_latents_orig),
-                    torch.from_numpy(noise),
-                    torch.tensor([t]),
-                ).numpy()
-                latents = (init_latents_proper * mask) + (latents * (1 - mask))
-            # call the callback, if provided
-            if i % callback_steps == 0:
-                if callback is not None:
-                    callback(i, t, latents)
-                if is_cancelled_callback is not None and is_cancelled_callback():
-                    return None
-        latents = 1 / 0.18215 * latents
-        # image = self.vae_decoder(latent_sample=latents)[0]
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
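The core of the fix, restated as a sketch: diffusers schedulers operate on torch tensors, while the ONNX UNet consumes and produces numpy arrays. The removed loop above passed numpy arrays straight into `scheduler.step`, which breaks against torch-based schedulers; the new loop round-trips through torch instead. Helper name below is illustrative, the call pattern is taken from the diff.

import numpy as np
import torch


def scheduler_step_numpy(scheduler, noise_pred: np.ndarray, t, latents: np.ndarray, **extra_step_kwargs) -> np.ndarray:
    # torch in, numpy out: keeps the torch scheduler and the ONNX UNet compatible
    output = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)
    return output.prev_sample.numpy()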
-        # it seems likes there is a problem for using half-precision vae decoder if batchsize>1
-        image = []
-        for i in range(latents.shape[0]):
-            image.append(self.vae_decoder(latent_sample=latents[i : i + 1])[0])
-        image = np.concatenate(image)
-        image = np.clip(image / 2 + 0.5, 0, 1)
-        image = image.transpose((0, 2, 3, 1))
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(
-                self.numpy_to_pil(image), return_tensors="np"
-            ).pixel_values.astype(image.dtype)
-            # There will throw an error if use safety_checker directly and batchsize>1
-            images, has_nsfw_concept = [], []
-            for i in range(image.shape[0]):
-                image_i, has_nsfw_concept_i = self.safety_checker(
-                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
-                )
-                images.append(image_i)
-                has_nsfw_concept.append(has_nsfw_concept_i[0])
-            image = np.concatenate(images)
-        else:
-            has_nsfw_concept = None
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+                latent_model_input = latent_model_input.numpy()
+                # predict the noise residual
+                noise_pred = self.unet(
+                    sample=latent_model_input,
+                    timestep=np.array([t], dtype=timestep_dtype),
+                    encoder_hidden_states=text_embeddings,
+                )
+                noise_pred = noise_pred[0]
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                # compute the previous noisy sample x_t -> x_t-1
+                scheduler_output = self.scheduler.step(
+                    torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+                )
+                latents = scheduler_output.prev_sample.numpy()
+                if mask is not None:
+                    # masking
+                    init_latents_proper = self.scheduler.add_noise(
+                        torch.from_numpy(init_latents_orig),
+                        torch.from_numpy(noise),
+                        t,
+                    ).numpy()
+                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if i % callback_steps == 0:
+                        if callback is not None:
+                            callback(i, t, latents)
+                        if is_cancelled_callback is not None and is_cancelled_callback():
+                            return None
+        # 9. Post-processing
+        image = self.decode_latents(latents)
+        # 10. Run safety checker
+        image, has_nsfw_concept = self.run_safety_checker(image)
+        # 11. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
         if not return_dict:
-            return (image, has_nsfw_concept)
+            return image, has_nsfw_concept
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
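For readers following the inpainting path: the `masking` block re-noises the original latents to the current timestep and pastes them back wherever the mask preserves the source image, so only the repaint region is actually denoised. A standalone sketch of the blend used above (mask convention assumed from the legacy inpaint preprocessing: 1 = keep original, 0 = repaint):

import numpy as np


def masked_update(init_latents_proper: np.ndarray, latents: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # mask == 1 restores the (re-noised) original content, mask == 0 keeps the denoised sample
    return (init_latents_proper * mask) + (latents * (1 - mask))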
...
@@ -748,7 +809,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         guidance_scale: float = 7.5,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[np.random.RandomState] = None,
+        generator: Optional[torch.Generator] = None,
         latents: Optional[np.ndarray] = None,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
...
@@ -783,8 +844,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         eta (`float`, *optional*, defaults to 0.0):
             Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
             [`schedulers.DDIMScheduler`], will be ignored for others.
-        generator (`np.random.RandomState`, *optional*):
-            A np.random.RandomState to make generation deterministic.
+        generator (`torch.Generator`, *optional*):
+            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+            deterministic.
         latents (`np.ndarray`, *optional*):
             Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
             generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
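A hedged usage sketch for `text2img` with the updated generator type (`pipe` loaded as in the earlier sketch; the `(word:weight)` syntax is the long-prompt-weighting convention this community pipeline parses, and the exact argument set is assumed from the signature fragments shown in this diff):

import torch

generator = torch.Generator().manual_seed(0)
image = pipe.text2img(
    "a photo of an astronaut riding a (white:1.3) horse",
    negative_prompt="(low quality:1.4)",
    num_inference_steps=30,
    generator=generator,
).images[0]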
...
@@ -839,7 +901,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
-        generator: Optional[np.random.RandomState] = None,
+        generator: Optional[torch.Generator] = None,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
...
@@ -878,8 +940,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         eta (`float`, *optional*, defaults to 0.0):
             Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
             [`schedulers.DDIMScheduler`], will be ignored for others.
-        generator (`np.random.RandomState`, *optional*):
-            A np.random.RandomState to make generation deterministic.
+        generator (`torch.Generator`, *optional*):
+            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+            deterministic.
         max_embeddings_multiples (`int`, *optional*, defaults to `3`):
             The max multiple length of prompt embeddings compared to the max output length of text encoder.
         output_type (`str`, *optional*, defaults to `"pil"`):
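For `img2img`, `strength` picks how far along the noise schedule the init image is pushed before denoising resumes, as the removed `init_timestep` arithmetic in the main hunk shows. A usage sketch under the same assumptions as above (`pipe` loaded earlier, file names hypothetical):

import torch
from PIL import Image

init = Image.open("sketch.png").convert("RGB").resize((512, 512))
image = pipe.img2img(
    "a detailed fantasy castle",
    image=init,
    strength=0.75,  # low values stay close to the init image, high values ignore it
    generator=torch.Generator().manual_seed(0),
).images[0]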
...
@@ -930,7 +993,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
-        generator: Optional[np.random.RandomState] = None,
+        generator: Optional[torch.Generator] = None,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
...
@@ -973,8 +1036,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         eta (`float`, *optional*, defaults to 0.0):
             Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
             [`schedulers.DDIMScheduler`], will be ignored for others.
-        generator (`np.random.RandomState`, *optional*):
-            A np.random.RandomState to make generation deterministic.
+        generator (`torch.Generator`, *optional*):
+            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+            deterministic.
         max_embeddings_multiples (`int`, *optional*, defaults to `3`):
             The max multiple length of prompt embeddings compared to the max output length of text encoder.
         output_type (`str`, *optional*, defaults to `"pil"`):
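An `inpaint` usage sketch under the same assumptions (`pipe` loaded earlier, file names hypothetical; the white-repaints/black-keeps mask convention is assumed from the legacy inpaint preprocessing this pipeline mirrors):

import torch
from PIL import Image

init = Image.open("photo.png").convert("RGB").resize((512, 512))
mask = Image.open("mask.png").convert("L").resize((512, 512))  # white = repaint, black = keep
image = pipe.inpaint(
    "a (golden:1.2) retriever sitting on a bench",
    image=init,
    mask_image=mask,
    strength=0.75,
    generator=torch.Generator().manual_seed(0),
).images[0]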
...