Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
160c377d
Commit
160c377d
authored
May 30, 2023
by
Patrick von Platen
Browse files
Make style
parent
bb22d546
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
126 deletions
+70
-126
examples/community/clip_guided_images_mixing_stable_diffusion.py
...s/community/clip_guided_images_mixing_stable_diffusion.py
+70
-126
No files found.
examples/community/clip_guided_images_mixing_stable_diffusion.py
View file @
160c377d
...
@@ -32,8 +32,7 @@ def preprocess(image, w, h):
...
@@ -32,8 +32,7 @@ def preprocess(image, w, h):
image
=
[
image
]
image
=
[
image
]
if
isinstance
(
image
[
0
],
PIL
.
Image
.
Image
):
if
isinstance
(
image
[
0
],
PIL
.
Image
.
Image
):
image
=
[
np
.
array
(
i
.
resize
((
w
,
h
),
resample
=
PIL_INTERPOLATION
[
'lanczos'
]))[
image
=
[
np
.
array
(
i
.
resize
((
w
,
h
),
resample
=
PIL_INTERPOLATION
[
"lanczos"
]))[
None
,
:]
for
i
in
image
]
None
,
:]
for
i
in
image
]
image
=
np
.
concatenate
(
image
,
axis
=
0
)
image
=
np
.
concatenate
(
image
,
axis
=
0
)
image
=
np
.
array
(
image
).
astype
(
np
.
float32
)
/
255.0
image
=
np
.
array
(
image
).
astype
(
np
.
float32
)
/
255.0
image
=
image
.
transpose
(
0
,
3
,
1
,
2
)
image
=
image
.
transpose
(
0
,
3
,
1
,
2
)
...
@@ -45,7 +44,6 @@ def preprocess(image, w, h):
...
@@ -45,7 +44,6 @@ def preprocess(image, w, h):
def
slerp
(
t
,
v0
,
v1
,
DOT_THRESHOLD
=
0.9995
):
def
slerp
(
t
,
v0
,
v1
,
DOT_THRESHOLD
=
0.9995
):
if
not
isinstance
(
v0
,
np
.
ndarray
):
if
not
isinstance
(
v0
,
np
.
ndarray
):
inputs_are_torch
=
True
inputs_are_torch
=
True
input_device
=
v0
.
device
input_device
=
v0
.
device
...
@@ -82,7 +80,6 @@ def set_requires_grad(model, value):
...
@@ -82,7 +80,6 @@ def set_requires_grad(model, value):
class
CLIPGuidedImagesMixingStableDiffusion
(
DiffusionPipeline
):
class
CLIPGuidedImagesMixingStableDiffusion
(
DiffusionPipeline
):
def
__init__
(
def
__init__
(
self
,
self
,
vae
:
AutoencoderKL
,
vae
:
AutoencoderKL
,
...
@@ -112,15 +109,14 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -112,15 +109,14 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
self
.
feature_extractor_size
=
(
self
.
feature_extractor_size
=
(
feature_extractor
.
size
feature_extractor
.
size
if
isinstance
(
feature_extractor
.
size
,
int
)
if
isinstance
(
feature_extractor
.
size
,
int
)
else
feature_extractor
.
size
[
'
shortest_edge
'
]
else
feature_extractor
.
size
[
"
shortest_edge
"
]
)
)
self
.
normalize
=
transforms
.
Normalize
(
self
.
normalize
=
transforms
.
Normalize
(
mean
=
feature_extractor
.
image_mean
,
std
=
feature_extractor
.
image_std
)
mean
=
feature_extractor
.
image_mean
,
std
=
feature_extractor
.
image_std
)
set_requires_grad
(
self
.
text_encoder
,
False
)
set_requires_grad
(
self
.
text_encoder
,
False
)
set_requires_grad
(
self
.
clip_model
,
False
)
set_requires_grad
(
self
.
clip_model
,
False
)
def
enable_attention_slicing
(
self
,
slice_size
:
Optional
[
Union
[
str
,
int
]]
=
'
auto
'
):
def
enable_attention_slicing
(
self
,
slice_size
:
Optional
[
Union
[
str
,
int
]]
=
"
auto
"
):
if
slice_size
==
'
auto
'
:
if
slice_size
==
"
auto
"
:
# half the attention head size is usually a good trade-off between
# half the attention head size is usually a good trade-off between
# speed and memory
# speed and memory
slice_size
=
self
.
unet
.
config
.
attention_head_dim
//
2
slice_size
=
self
.
unet
.
config
.
attention_head_dim
//
2
...
@@ -143,8 +139,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -143,8 +139,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
def
get_timesteps
(
self
,
num_inference_steps
,
strength
,
device
):
def
get_timesteps
(
self
,
num_inference_steps
,
strength
,
device
):
# get the original timestep using init_timestep
# get the original timestep using init_timestep
init_timestep
=
min
(
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
:]
...
@@ -153,15 +148,13 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -153,15 +148,13 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
def
prepare_latents
(
self
,
image
,
timestep
,
batch_size
,
dtype
,
device
,
generator
=
None
):
def
prepare_latents
(
self
,
image
,
timestep
,
batch_size
,
dtype
,
device
,
generator
=
None
):
if
not
isinstance
(
image
,
torch
.
Tensor
):
if
not
isinstance
(
image
,
torch
.
Tensor
):
raise
ValueError
(
raise
ValueError
(
f
"`image` has to be of type `torch.Tensor` but is
{
type
(
image
)
}
"
)
f
'`image` has to be of type `torch.Tensor` but is
{
type
(
image
)
}
'
)
image
=
image
.
to
(
device
=
device
,
dtype
=
dtype
)
image
=
image
.
to
(
device
=
device
,
dtype
=
dtype
)
if
isinstance
(
generator
,
list
):
if
isinstance
(
generator
,
list
):
init_latents
=
[
init_latents
=
[
self
.
vae
.
encode
(
image
[
i
:
i
+
1
]).
latent_dist
.
sample
(
generator
[
i
])
for
i
in
range
(
batch_size
)
self
.
vae
.
encode
(
image
[
i
:
i
+
1
]).
latent_dist
.
sample
(
generator
[
i
])
for
i
in
range
(
batch_size
)
]
]
init_latents
=
torch
.
cat
(
init_latents
,
dim
=
0
)
init_latents
=
torch
.
cat
(
init_latents
,
dim
=
0
)
else
:
else
:
...
@@ -171,8 +164,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -171,8 +164,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
init_latents
=
0.18215
*
init_latents
init_latents
=
0.18215
*
init_latents
init_latents
=
init_latents
.
repeat_interleave
(
batch_size
,
dim
=
0
)
init_latents
=
init_latents
.
repeat_interleave
(
batch_size
,
dim
=
0
)
noise
=
randn_tensor
(
init_latents
.
shape
,
noise
=
randn_tensor
(
init_latents
.
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
dtype
)
generator
=
generator
,
device
=
device
,
dtype
=
dtype
)
# get latents
# get latents
init_latents
=
self
.
scheduler
.
add_noise
(
init_latents
,
noise
,
timestep
)
init_latents
=
self
.
scheduler
.
add_noise
(
init_latents
,
noise
,
timestep
)
...
@@ -183,21 +175,16 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -183,21 +175,16 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
def
get_image_description
(
self
,
image
):
def
get_image_description
(
self
,
image
):
transformed_image
=
self
.
coca_transform
(
image
).
unsqueeze
(
0
)
transformed_image
=
self
.
coca_transform
(
image
).
unsqueeze
(
0
)
with
torch
.
no_grad
(),
torch
.
cuda
.
amp
.
autocast
():
with
torch
.
no_grad
(),
torch
.
cuda
.
amp
.
autocast
():
generated
=
self
.
coca_model
.
generate
(
transformed_image
.
to
(
generated
=
self
.
coca_model
.
generate
(
transformed_image
.
to
(
device
=
self
.
device
,
dtype
=
self
.
coca_model
.
dtype
))
device
=
self
.
device
,
dtype
=
self
.
coca_model
.
dtype
))
generated
=
self
.
coca_tokenizer
.
decode
(
generated
[
0
].
cpu
().
numpy
())
generated
=
self
.
coca_tokenizer
.
decode
(
generated
[
0
].
cpu
().
numpy
())
return
generated
.
split
(
'
<end_of_text>
'
)[
0
].
replace
(
'
<start_of_text>
'
,
''
).
rstrip
(
'
.,
'
)
return
generated
.
split
(
"
<end_of_text>
"
)[
0
].
replace
(
"
<start_of_text>
"
,
""
).
rstrip
(
"
.,
"
)
def
get_clip_image_embeddings
(
self
,
image
,
batch_size
):
def
get_clip_image_embeddings
(
self
,
image
,
batch_size
):
clip_image_input
=
self
.
feature_extractor
.
preprocess
(
image
)
clip_image_input
=
self
.
feature_extractor
.
preprocess
(
image
)
clip_image_features
=
torch
.
from_numpy
(
clip_image_features
=
torch
.
from_numpy
(
clip_image_input
[
"pixel_values"
][
0
]).
unsqueeze
(
0
).
to
(
self
.
device
).
half
()
clip_image_input
[
'pixel_values'
][
0
]).
unsqueeze
(
0
).
to
(
self
.
device
).
half
()
image_embeddings_clip
=
self
.
clip_model
.
get_image_features
(
clip_image_features
)
image_embeddings_clip
=
self
.
clip_model
.
get_image_features
(
image_embeddings_clip
=
image_embeddings_clip
/
image_embeddings_clip
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
clip_image_features
)
image_embeddings_clip
=
image_embeddings_clip
.
repeat_interleave
(
batch_size
,
dim
=
0
)
image_embeddings_clip
=
image_embeddings_clip
/
\
image_embeddings_clip
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
image_embeddings_clip
=
image_embeddings_clip
.
repeat_interleave
(
batch_size
,
dim
=
0
)
return
image_embeddings_clip
return
image_embeddings_clip
@
torch
.
enable_grad
()
@
torch
.
enable_grad
()
...
@@ -213,20 +200,17 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -213,20 +200,17 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
):
):
latents
=
latents
.
detach
().
requires_grad_
()
latents
=
latents
.
detach
().
requires_grad_
()
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latents
,
timestep
)
latents
,
timestep
)
# predict the noise residual
# predict the noise residual
noise_pred
=
self
.
unet
(
latent_model_input
,
timestep
,
noise_pred
=
self
.
unet
(
latent_model_input
,
timestep
,
encoder_hidden_states
=
text_embeddings
).
sample
encoder_hidden_states
=
text_embeddings
).
sample
if
isinstance
(
self
.
scheduler
,
(
PNDMScheduler
,
DDIMScheduler
,
DPMSolverMultistepScheduler
)):
if
isinstance
(
self
.
scheduler
,
(
PNDMScheduler
,
DDIMScheduler
,
DPMSolverMultistepScheduler
)):
alpha_prod_t
=
self
.
scheduler
.
alphas_cumprod
[
timestep
]
alpha_prod_t
=
self
.
scheduler
.
alphas_cumprod
[
timestep
]
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t
=
1
-
alpha_prod_t
# compute predicted original sample from predicted noise also called
# compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample
=
(
pred_original_sample
=
(
latents
-
beta_prod_t
**
(
0.5
)
*
noise_pred
)
/
alpha_prod_t
**
(
0.5
)
latents
-
beta_prod_t
**
(
0.5
)
*
noise_pred
)
/
alpha_prod_t
**
(
0.5
)
fac
=
torch
.
sqrt
(
beta_prod_t
)
fac
=
torch
.
sqrt
(
beta_prod_t
)
sample
=
pred_original_sample
*
(
fac
)
+
latents
*
(
1
-
fac
)
sample
=
pred_original_sample
*
(
fac
)
+
latents
*
(
1
-
fac
)
...
@@ -234,8 +218,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -234,8 +218,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
sigma
=
self
.
scheduler
.
sigmas
[
index
]
sigma
=
self
.
scheduler
.
sigmas
[
index
]
sample
=
latents
-
sigma
*
noise_pred
sample
=
latents
-
sigma
*
noise_pred
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"scheduler type
{
type
(
self
.
scheduler
)
}
not supported"
)
f
'scheduler type
{
type
(
self
.
scheduler
)
}
not supported'
)
# Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
# Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
sample
=
1
/
0.18215
*
sample
sample
=
1
/
0.18215
*
sample
...
@@ -246,11 +229,9 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -246,11 +229,9 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
image
=
self
.
normalize
(
image
).
to
(
latents
.
dtype
)
image
=
self
.
normalize
(
image
).
to
(
latents
.
dtype
)
image_embeddings_clip
=
self
.
clip_model
.
get_image_features
(
image
)
image_embeddings_clip
=
self
.
clip_model
.
get_image_features
(
image
)
image_embeddings_clip
=
image_embeddings_clip
/
\
image_embeddings_clip
=
image_embeddings_clip
/
image_embeddings_clip
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
image_embeddings_clip
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
loss
=
spherical_dist_loss
(
loss
=
spherical_dist_loss
(
image_embeddings_clip
,
original_image_embeddings_clip
).
mean
()
*
clip_guidance_scale
image_embeddings_clip
,
original_image_embeddings_clip
).
mean
()
*
clip_guidance_scale
grads
=
-
torch
.
autograd
.
grad
(
loss
,
latents
)[
0
]
grads
=
-
torch
.
autograd
.
grad
(
loss
,
latents
)[
0
]
...
@@ -277,121 +258,101 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -277,121 +258,101 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
eta
:
float
=
0.0
,
eta
:
float
=
0.0
,
clip_guidance_scale
:
Optional
[
float
]
=
100
,
clip_guidance_scale
:
Optional
[
float
]
=
100
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
output_type
:
Optional
[
str
]
=
'
pil
'
,
output_type
:
Optional
[
str
]
=
"
pil
"
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
slerp_latent_style_strength
:
float
=
0.8
,
slerp_latent_style_strength
:
float
=
0.8
,
slerp_prompt_style_strength
:
float
=
0.1
,
slerp_prompt_style_strength
:
float
=
0.1
,
slerp_clip_image_style_strength
:
float
=
0.1
,
slerp_clip_image_style_strength
:
float
=
0.1
,
):
):
if
isinstance
(
generator
,
list
)
and
len
(
generator
)
!=
batch_size
:
if
isinstance
(
generator
,
list
)
and
len
(
generator
)
!=
batch_size
:
raise
ValueError
(
raise
ValueError
(
f
"You have passed
{
batch_size
}
batch_size, but only
{
len
(
generator
)
}
generators."
)
f
'You have passed
{
batch_size
}
batch_size, but only
{
len
(
generator
)
}
generators.'
)
if
height
%
8
!=
0
or
width
%
8
!=
0
:
if
height
%
8
!=
0
or
width
%
8
!=
0
:
raise
ValueError
(
raise
ValueError
(
f
"`height` and `width` have to be divisible by 8 but are
{
height
}
and
{
width
}
."
)
f
'`height` and `width` have to be divisible by 8 but are
{
height
}
and
{
width
}
.'
)
if
isinstance
(
generator
,
torch
.
Generator
)
and
batch_size
>
1
:
if
isinstance
(
generator
,
torch
.
Generator
)
and
batch_size
>
1
:
generator
=
[
generator
]
+
[
None
]
*
(
batch_size
-
1
)
generator
=
[
generator
]
+
[
None
]
*
(
batch_size
-
1
)
coca_is_none
=
[
coca_is_none
=
[
(
'
model
'
,
self
.
coca_model
is
None
),
(
"
model
"
,
self
.
coca_model
is
None
),
(
'
tokenizer
'
,
self
.
coca_tokenizer
is
None
),
(
"
tokenizer
"
,
self
.
coca_tokenizer
is
None
),
(
'
transform
'
,
self
.
coca_transform
is
None
)
(
"
transform
"
,
self
.
coca_transform
is
None
)
,
]
]
coca_is_none
=
[
x
[
0
]
for
x
in
coca_is_none
if
x
[
1
]]
coca_is_none
=
[
x
[
0
]
for
x
in
coca_is_none
if
x
[
1
]]
coca_is_none_str
=
'
,
'
.
join
(
coca_is_none
)
coca_is_none_str
=
"
,
"
.
join
(
coca_is_none
)
# generate prompts with coca model if prompt is None
# generate prompts with coca model if prompt is None
if
content_prompt
is
None
:
if
content_prompt
is
None
:
if
len
(
coca_is_none
):
if
len
(
coca_is_none
):
raise
ValueError
(
raise
ValueError
(
f
'
Content prompt is None and CoCa [
{
coca_is_none_str
}
] is None.
'
f
"
Content prompt is None and CoCa [
{
coca_is_none_str
}
] is None.
"
f
'
Set prompt or pass Coca [
{
coca_is_none_str
}
] to DiffusionPipeline.
'
f
"
Set prompt or pass Coca [
{
coca_is_none_str
}
] to DiffusionPipeline.
"
)
)
content_prompt
=
self
.
get_image_description
(
content_image
)
content_prompt
=
self
.
get_image_description
(
content_image
)
if
style_prompt
is
None
:
if
style_prompt
is
None
:
if
len
(
coca_is_none
):
if
len
(
coca_is_none
):
raise
ValueError
(
raise
ValueError
(
f
'
Style prompt is None and CoCa [
{
coca_is_none_str
}
] is None.
'
f
"
Style prompt is None and CoCa [
{
coca_is_none_str
}
] is None.
"
f
'
Set prompt or pass Coca [
{
coca_is_none_str
}
] to DiffusionPipeline.
'
f
"
Set prompt or pass Coca [
{
coca_is_none_str
}
] to DiffusionPipeline.
"
)
)
style_prompt
=
self
.
get_image_description
(
style_image
)
style_prompt
=
self
.
get_image_description
(
style_image
)
# get prompt text embeddings for content and style
# get prompt text embeddings for content and style
content_text_input
=
self
.
tokenizer
(
content_text_input
=
self
.
tokenizer
(
content_prompt
,
content_prompt
,
padding
=
'
max_length
'
,
padding
=
"
max_length
"
,
max_length
=
self
.
tokenizer
.
model_max_length
,
max_length
=
self
.
tokenizer
.
model_max_length
,
truncation
=
True
,
truncation
=
True
,
return_tensors
=
'
pt
'
,
return_tensors
=
"
pt
"
,
)
)
content_text_embeddings
=
self
.
text_encoder
(
content_text_embeddings
=
self
.
text_encoder
(
content_text_input
.
input_ids
.
to
(
self
.
device
))[
0
]
content_text_input
.
input_ids
.
to
(
self
.
device
))[
0
]
style_text_input
=
self
.
tokenizer
(
style_text_input
=
self
.
tokenizer
(
style_prompt
,
style_prompt
,
padding
=
'
max_length
'
,
padding
=
"
max_length
"
,
max_length
=
self
.
tokenizer
.
model_max_length
,
max_length
=
self
.
tokenizer
.
model_max_length
,
truncation
=
True
,
truncation
=
True
,
return_tensors
=
'
pt
'
,
return_tensors
=
"
pt
"
,
)
)
style_text_embeddings
=
self
.
text_encoder
(
style_text_embeddings
=
self
.
text_encoder
(
style_text_input
.
input_ids
.
to
(
self
.
device
))[
0
]
style_text_input
.
input_ids
.
to
(
self
.
device
))[
0
]
text_embeddings
=
slerp
(
text_embeddings
=
slerp
(
slerp_prompt_style_strength
,
content_text_embeddings
,
style_text_embeddings
)
slerp_prompt_style_strength
,
content_text_embeddings
,
style_text_embeddings
)
# duplicate text embeddings for each generation per prompt
# duplicate text embeddings for each generation per prompt
text_embeddings
=
text_embeddings
.
repeat_interleave
(
batch_size
,
dim
=
0
)
text_embeddings
=
text_embeddings
.
repeat_interleave
(
batch_size
,
dim
=
0
)
# set timesteps
# set timesteps
accepts_offset
=
'offset'
in
set
(
inspect
.
signature
(
accepts_offset
=
"offset"
in
set
(
inspect
.
signature
(
self
.
scheduler
.
set_timesteps
).
parameters
.
keys
())
self
.
scheduler
.
set_timesteps
).
parameters
.
keys
())
extra_set_kwargs
=
{}
extra_set_kwargs
=
{}
if
accepts_offset
:
if
accepts_offset
:
extra_set_kwargs
[
'
offset
'
]
=
1
extra_set_kwargs
[
"
offset
"
]
=
1
self
.
scheduler
.
set_timesteps
(
num_inference_steps
,
**
extra_set_kwargs
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
,
**
extra_set_kwargs
)
# Some schedulers like PNDM have timesteps as arrays
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
# It's more optimized to move all timesteps to correct device beforehand
self
.
scheduler
.
timesteps
.
to
(
self
.
device
)
self
.
scheduler
.
timesteps
.
to
(
self
.
device
)
timesteps
,
num_inference_steps
=
self
.
get_timesteps
(
timesteps
,
num_inference_steps
=
self
.
get_timesteps
(
num_inference_steps
,
noise_strength
,
self
.
device
)
num_inference_steps
,
noise_strength
,
self
.
device
)
latent_timestep
=
timesteps
[:
1
].
repeat
(
batch_size
)
latent_timestep
=
timesteps
[:
1
].
repeat
(
batch_size
)
# Preprocess image
# Preprocess image
preprocessed_content_image
=
preprocess
(
content_image
,
width
,
height
)
preprocessed_content_image
=
preprocess
(
content_image
,
width
,
height
)
content_latents
=
self
.
prepare_latents
(
content_latents
=
self
.
prepare_latents
(
preprocessed_content_image
,
preprocessed_content_image
,
latent_timestep
,
batch_size
,
text_embeddings
.
dtype
,
self
.
device
,
generator
latent_timestep
,
batch_size
,
text_embeddings
.
dtype
,
self
.
device
,
generator
)
)
preprocessed_style_image
=
preprocess
(
style_image
,
width
,
height
)
preprocessed_style_image
=
preprocess
(
style_image
,
width
,
height
)
style_latents
=
self
.
prepare_latents
(
style_latents
=
self
.
prepare_latents
(
preprocessed_style_image
,
preprocessed_style_image
,
latent_timestep
,
batch_size
,
text_embeddings
.
dtype
,
self
.
device
,
generator
latent_timestep
,
batch_size
,
text_embeddings
.
dtype
,
self
.
device
,
generator
)
)
latents
=
slerp
(
slerp_latent_style_strength
,
latents
=
slerp
(
slerp_latent_style_strength
,
content_latents
,
style_latents
)
content_latents
,
style_latents
)
if
clip_guidance_scale
>
0
:
if
clip_guidance_scale
>
0
:
content_clip_image_embedding
=
self
.
get_clip_image_embeddings
(
content_clip_image_embedding
=
self
.
get_clip_image_embeddings
(
content_image
,
batch_size
)
content_image
,
batch_size
)
style_clip_image_embedding
=
self
.
get_clip_image_embeddings
(
style_image
,
batch_size
)
style_clip_image_embedding
=
self
.
get_clip_image_embeddings
(
style_image
,
batch_size
)
clip_image_embeddings
=
slerp
(
clip_image_embeddings
=
slerp
(
slerp_clip_image_style_strength
,
content_clip_image_embedding
,
style_clip_image_embedding
)
slerp_clip_image_style_strength
,
content_clip_image_embedding
,
style_clip_image_embedding
)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
...
@@ -400,13 +361,10 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -400,13 +361,10 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
# get unconditional embeddings for classifier free guidance
# get unconditional embeddings for classifier free guidance
if
do_classifier_free_guidance
:
if
do_classifier_free_guidance
:
max_length
=
content_text_input
.
input_ids
.
shape
[
-
1
]
max_length
=
content_text_input
.
input_ids
.
shape
[
-
1
]
uncond_input
=
self
.
tokenizer
(
uncond_input
=
self
.
tokenizer
([
""
],
padding
=
"max_length"
,
max_length
=
max_length
,
return_tensors
=
"pt"
)
[
''
],
padding
=
'max_length'
,
max_length
=
max_length
,
return_tensors
=
'pt'
)
uncond_embeddings
=
self
.
text_encoder
(
uncond_input
.
input_ids
.
to
(
self
.
device
))[
0
]
uncond_embeddings
=
self
.
text_encoder
(
uncond_input
.
input_ids
.
to
(
self
.
device
))[
0
]
# duplicate unconditional embeddings for each generation per prompt
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings
=
uncond_embeddings
.
repeat_interleave
(
uncond_embeddings
=
uncond_embeddings
.
repeat_interleave
(
batch_size
,
dim
=
0
)
batch_size
,
dim
=
0
)
# For classifier free guidance, we need to do two forward passes.
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# Here we concatenate the unconditional and text embeddings into a single batch
...
@@ -418,25 +376,19 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -418,25 +376,19 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
# However this currently doesn't work in `mps`.
latents_shape
=
(
latents_shape
=
(
batch_size
,
self
.
unet
.
config
.
in_channels
,
height
//
8
,
width
//
8
)
batch_size
,
self
.
unet
.
config
.
in_channels
,
height
//
8
,
width
//
8
)
latents_dtype
=
text_embeddings
.
dtype
latents_dtype
=
text_embeddings
.
dtype
if
latents
is
None
:
if
latents
is
None
:
if
self
.
device
.
type
==
'
mps
'
:
if
self
.
device
.
type
==
"
mps
"
:
# randn does not work reproducibly on mps
# randn does not work reproducibly on mps
latents
=
torch
.
randn
(
latents
=
torch
.
randn
(
latents_shape
,
generator
=
generator
,
device
=
"cpu"
,
dtype
=
latents_dtype
).
to
(
latents_shape
,
self
.
device
generator
=
generator
,
)
device
=
'cpu'
,
dtype
=
latents_dtype
).
to
(
self
.
device
)
else
:
else
:
latents
=
torch
.
randn
(
latents
=
torch
.
randn
(
latents_shape
,
generator
=
generator
,
device
=
self
.
device
,
dtype
=
latents_dtype
)
latents_shape
,
generator
=
generator
,
device
=
self
.
device
,
dtype
=
latents_dtype
)
else
:
else
:
if
latents
.
shape
!=
latents_shape
:
if
latents
.
shape
!=
latents_shape
:
raise
ValueError
(
raise
ValueError
(
f
"Unexpected latents shape, got
{
latents
.
shape
}
, expected
{
latents_shape
}
"
)
f
'Unexpected latents shape, got
{
latents
.
shape
}
, expected
{
latents_shape
}
'
)
latents
=
latents
.
to
(
self
.
device
)
latents
=
latents
.
to
(
self
.
device
)
# scale the initial noise by the standard deviation required by the scheduler
# scale the initial noise by the standard deviation required by the scheduler
...
@@ -446,41 +398,34 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -446,41 +398,34 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
# and should be between [0, 1]
accepts_eta
=
'eta'
in
set
(
inspect
.
signature
(
accepts_eta
=
"eta"
in
set
(
inspect
.
signature
(
self
.
scheduler
.
step
).
parameters
.
keys
())
self
.
scheduler
.
step
).
parameters
.
keys
())
extra_step_kwargs
=
{}
extra_step_kwargs
=
{}
if
accepts_eta
:
if
accepts_eta
:
extra_step_kwargs
[
'
eta
'
]
=
eta
extra_step_kwargs
[
"
eta
"
]
=
eta
# check if the scheduler accepts generator
# check if the scheduler accepts generator
accepts_generator
=
'generator'
in
set
(
accepts_generator
=
"generator"
in
set
(
inspect
.
signature
(
self
.
scheduler
.
step
).
parameters
.
keys
())
inspect
.
signature
(
self
.
scheduler
.
step
).
parameters
.
keys
())
if
accepts_generator
:
if
accepts_generator
:
extra_step_kwargs
[
'
generator
'
]
=
generator
extra_step_kwargs
[
"
generator
"
]
=
generator
with
self
.
progress_bar
(
total
=
num_inference_steps
):
with
self
.
progress_bar
(
total
=
num_inference_steps
):
for
i
,
t
in
enumerate
(
timesteps
):
for
i
,
t
in
enumerate
(
timesteps
):
# expand the latents if we are doing classifier free guidance
# expand the latents if we are doing classifier free guidance
latent_model_input
=
torch
.
cat
(
latent_model_input
=
torch
.
cat
([
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
[
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latent_model_input
,
t
)
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latent_model_input
,
t
)
# predict the noise residual
# predict the noise residual
noise_pred
=
self
.
unet
(
noise_pred
=
self
.
unet
(
latent_model_input
,
t
,
encoder_hidden_states
=
text_embeddings
).
sample
latent_model_input
,
t
,
encoder_hidden_states
=
text_embeddings
).
sample
# perform classifier free guidance
# perform classifier free guidance
if
do_classifier_free_guidance
:
if
do_classifier_free_guidance
:
noise_pred_uncond
,
noise_pred_text
=
noise_pred
.
chunk
(
2
)
noise_pred_uncond
,
noise_pred_text
=
noise_pred
.
chunk
(
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
\
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_pred_text
-
noise_pred_uncond
)
(
noise_pred_text
-
noise_pred_uncond
)
# perform clip guidance
# perform clip guidance
if
clip_guidance_scale
>
0
:
if
clip_guidance_scale
>
0
:
text_embeddings_for_guidance
=
(
text_embeddings_for_guidance
=
(
text_embeddings
.
chunk
(
text_embeddings
.
chunk
(
2
)[
1
]
if
do_classifier_free_guidance
else
text_embeddings
2
)[
1
]
if
do_classifier_free_guidance
else
text_embeddings
)
)
noise_pred
,
latents
=
self
.
cond_fn
(
noise_pred
,
latents
=
self
.
cond_fn
(
latents
,
latents
,
...
@@ -493,8 +438,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -493,8 +438,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
)
)
# compute the previous noisy sample x_t -> x_t-1
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
# Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
latents
=
1
/
0.18215
*
latents
latents
=
1
/
0.18215
*
latents
...
@@ -503,7 +447,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
...
@@ -503,7 +447,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
'
pil
'
:
if
output_type
==
"
pil
"
:
image
=
self
.
numpy_to_pil
(
image
)
image
=
self
.
numpy_to_pil
(
image
)
if
not
return_dict
:
if
not
return_dict
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment