Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
160c377d
Commit
160c377d
authored
May 30, 2023
by
Patrick von Platen
Browse files
Make style
parent
bb22d546
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
126 deletions
+70
-126
examples/community/clip_guided_images_mixing_stable_diffusion.py
...s/community/clip_guided_images_mixing_stable_diffusion.py
+70
-126
No files found.
examples/community/clip_guided_images_mixing_stable_diffusion.py
View file @
160c377d
...
...
@@ -32,8 +32,7 @@ def preprocess(image, w, h):
image
=
[
image
]
if
isinstance
(
image
[
0
],
PIL
.
Image
.
Image
):
image
=
[
np
.
array
(
i
.
resize
((
w
,
h
),
resample
=
PIL_INTERPOLATION
[
'lanczos'
]))[
None
,
:]
for
i
in
image
]
image
=
[
np
.
array
(
i
.
resize
((
w
,
h
),
resample
=
PIL_INTERPOLATION
[
"lanczos"
]))[
None
,
:]
for
i
in
image
]
image
=
np
.
concatenate
(
image
,
axis
=
0
)
image
=
np
.
array
(
image
).
astype
(
np
.
float32
)
/
255.0
image
=
image
.
transpose
(
0
,
3
,
1
,
2
)
...
...
@@ -45,7 +44,6 @@ def preprocess(image, w, h):
def
slerp
(
t
,
v0
,
v1
,
DOT_THRESHOLD
=
0.9995
):
if
not
isinstance
(
v0
,
np
.
ndarray
):
inputs_are_torch
=
True
input_device
=
v0
.
device
...
...
@@ -82,7 +80,6 @@ def set_requires_grad(model, value):
class
CLIPGuidedImagesMixingStableDiffusion
(
DiffusionPipeline
):
def
__init__
(
self
,
vae
:
AutoencoderKL
,
...
...
@@ -112,15 +109,14 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
self
.
feature_extractor_size
=
(
feature_extractor
.
size
if
isinstance
(
feature_extractor
.
size
,
int
)
else
feature_extractor
.
size
[
'
shortest_edge
'
]
else
feature_extractor
.
size
[
"
shortest_edge
"
]
)
self
.
normalize
=
transforms
.
Normalize
(
mean
=
feature_extractor
.
image_mean
,
std
=
feature_extractor
.
image_std
)
self
.
normalize
=
transforms
.
Normalize
(
mean
=
feature_extractor
.
image_mean
,
std
=
feature_extractor
.
image_std
)
set_requires_grad
(
self
.
text_encoder
,
False
)
set_requires_grad
(
self
.
clip_model
,
False
)
def
enable_attention_slicing
(
self
,
slice_size
:
Optional
[
Union
[
str
,
int
]]
=
'
auto
'
):
if
slice_size
==
'
auto
'
:
def
enable_attention_slicing
(
self
,
slice_size
:
Optional
[
Union
[
str
,
int
]]
=
"
auto
"
):
if
slice_size
==
"
auto
"
:
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size
=
self
.
unet
.
config
.
attention_head_dim
//
2
...
...
@@ -143,8 +139,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
def
get_timesteps
(
self
,
num_inference_steps
,
strength
,
device
):
# get the original timestep using init_timestep
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
:]
...
...
@@ -153,15 +148,13 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
def
prepare_latents
(
self
,
image
,
timestep
,
batch_size
,
dtype
,
device
,
generator
=
None
):
if
not
isinstance
(
image
,
torch
.
Tensor
):
raise
ValueError
(
f
'`image` has to be of type `torch.Tensor` but is
{
type
(
image
)
}
'
)
raise
ValueError
(
f
"`image` has to be of type `torch.Tensor` but is
{
type
(
image
)
}
"
)
image
=
image
.
to
(
device
=
device
,
dtype
=
dtype
)
if
isinstance
(
generator
,
list
):
init_latents
=
[
self
.
vae
.
encode
(
image
[
i
:
i
+
1
]).
latent_dist
.
sample
(
generator
[
i
])
for
i
in
range
(
batch_size
)
self
.
vae
.
encode
(
image
[
i
:
i
+
1
]).
latent_dist
.
sample
(
generator
[
i
])
for
i
in
range
(
batch_size
)
]
init_latents
=
torch
.
cat
(
init_latents
,
dim
=
0
)
else
:
...
...
@@ -171,8 +164,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
init_latents
=
0.18215
*
init_latents
init_latents
=
init_latents
.
repeat_interleave
(
batch_size
,
dim
=
0
)
noise
=
randn_tensor
(
init_latents
.
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
dtype
)
noise
=
randn_tensor
(
init_latents
.
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
dtype
)
# get latents
init_latents
=
self
.
scheduler
.
add_noise
(
init_latents
,
noise
,
timestep
)
...
...
@@ -183,21 +175,16 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
def
get_image_description
(
self
,
image
):
transformed_image
=
self
.
coca_transform
(
image
).
unsqueeze
(
0
)
with
torch
.
no_grad
(),
torch
.
cuda
.
amp
.
autocast
():
generated
=
self
.
coca_model
.
generate
(
transformed_image
.
to
(
device
=
self
.
device
,
dtype
=
self
.
coca_model
.
dtype
))
generated
=
self
.
coca_model
.
generate
(
transformed_image
.
to
(
device
=
self
.
device
,
dtype
=
self
.
coca_model
.
dtype
))
generated
=
self
.
coca_tokenizer
.
decode
(
generated
[
0
].
cpu
().
numpy
())
return
generated
.
split
(
'
<end_of_text>
'
)[
0
].
replace
(
'
<start_of_text>
'
,
''
).
rstrip
(
'
.,
'
)
return
generated
.
split
(
"
<end_of_text>
"
)[
0
].
replace
(
"
<start_of_text>
"
,
""
).
rstrip
(
"
.,
"
)
def
get_clip_image_embeddings
(
self
,
image
,
batch_size
):
clip_image_input
=
self
.
feature_extractor
.
preprocess
(
image
)
clip_image_features
=
torch
.
from_numpy
(
clip_image_input
[
'pixel_values'
][
0
]).
unsqueeze
(
0
).
to
(
self
.
device
).
half
()
image_embeddings_clip
=
self
.
clip_model
.
get_image_features
(
clip_image_features
)
image_embeddings_clip
=
image_embeddings_clip
/
\
image_embeddings_clip
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
image_embeddings_clip
=
image_embeddings_clip
.
repeat_interleave
(
batch_size
,
dim
=
0
)
clip_image_features
=
torch
.
from_numpy
(
clip_image_input
[
"pixel_values"
][
0
]).
unsqueeze
(
0
).
to
(
self
.
device
).
half
()
image_embeddings_clip
=
self
.
clip_model
.
get_image_features
(
clip_image_features
)
image_embeddings_clip
=
image_embeddings_clip
/
image_embeddings_clip
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
image_embeddings_clip
=
image_embeddings_clip
.
repeat_interleave
(
batch_size
,
dim
=
0
)
return
image_embeddings_clip
@
torch
.
enable_grad
()
...
...
@@ -213,20 +200,17 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
):
latents
=
latents
.
detach
().
requires_grad_
()
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latents
,
timestep
)
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latents
,
timestep
)
# predict the noise residual
noise_pred
=
self
.
unet
(
latent_model_input
,
timestep
,
encoder_hidden_states
=
text_embeddings
).
sample
noise_pred
=
self
.
unet
(
latent_model_input
,
timestep
,
encoder_hidden_states
=
text_embeddings
).
sample
if
isinstance
(
self
.
scheduler
,
(
PNDMScheduler
,
DDIMScheduler
,
DPMSolverMultistepScheduler
)):
alpha_prod_t
=
self
.
scheduler
.
alphas_cumprod
[
timestep
]
beta_prod_t
=
1
-
alpha_prod_t
# compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample
=
(
latents
-
beta_prod_t
**
(
0.5
)
*
noise_pred
)
/
alpha_prod_t
**
(
0.5
)
pred_original_sample
=
(
latents
-
beta_prod_t
**
(
0.5
)
*
noise_pred
)
/
alpha_prod_t
**
(
0.5
)
fac
=
torch
.
sqrt
(
beta_prod_t
)
sample
=
pred_original_sample
*
(
fac
)
+
latents
*
(
1
-
fac
)
...
...
@@ -234,8 +218,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
sigma
=
self
.
scheduler
.
sigmas
[
index
]
sample
=
latents
-
sigma
*
noise_pred
else
:
raise
ValueError
(
f
'scheduler type
{
type
(
self
.
scheduler
)
}
not supported'
)
raise
ValueError
(
f
"scheduler type
{
type
(
self
.
scheduler
)
}
not supported"
)
# Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
sample
=
1
/
0.18215
*
sample
...
...
@@ -246,11 +229,9 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
image
=
self
.
normalize
(
image
).
to
(
latents
.
dtype
)
image_embeddings_clip
=
self
.
clip_model
.
get_image_features
(
image
)
image_embeddings_clip
=
image_embeddings_clip
/
\
image_embeddings_clip
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
image_embeddings_clip
=
image_embeddings_clip
/
image_embeddings_clip
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
loss
=
spherical_dist_loss
(
image_embeddings_clip
,
original_image_embeddings_clip
).
mean
()
*
clip_guidance_scale
loss
=
spherical_dist_loss
(
image_embeddings_clip
,
original_image_embeddings_clip
).
mean
()
*
clip_guidance_scale
grads
=
-
torch
.
autograd
.
grad
(
loss
,
latents
)[
0
]
...
...
@@ -277,121 +258,101 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
eta
:
float
=
0.0
,
clip_guidance_scale
:
Optional
[
float
]
=
100
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
output_type
:
Optional
[
str
]
=
'
pil
'
,
output_type
:
Optional
[
str
]
=
"
pil
"
,
return_dict
:
bool
=
True
,
slerp_latent_style_strength
:
float
=
0.8
,
slerp_prompt_style_strength
:
float
=
0.1
,
slerp_clip_image_style_strength
:
float
=
0.1
,
):
if
isinstance
(
generator
,
list
)
and
len
(
generator
)
!=
batch_size
:
raise
ValueError
(
f
'You have passed
{
batch_size
}
batch_size, but only
{
len
(
generator
)
}
generators.'
)
raise
ValueError
(
f
"You have passed
{
batch_size
}
batch_size, but only
{
len
(
generator
)
}
generators."
)
if
height
%
8
!=
0
or
width
%
8
!=
0
:
raise
ValueError
(
f
'`height` and `width` have to be divisible by 8 but are
{
height
}
and
{
width
}
.'
)
raise
ValueError
(
f
"`height` and `width` have to be divisible by 8 but are
{
height
}
and
{
width
}
."
)
if
isinstance
(
generator
,
torch
.
Generator
)
and
batch_size
>
1
:
generator
=
[
generator
]
+
[
None
]
*
(
batch_size
-
1
)
coca_is_none
=
[
(
'
model
'
,
self
.
coca_model
is
None
),
(
'
tokenizer
'
,
self
.
coca_tokenizer
is
None
),
(
'
transform
'
,
self
.
coca_transform
is
None
)
(
"
model
"
,
self
.
coca_model
is
None
),
(
"
tokenizer
"
,
self
.
coca_tokenizer
is
None
),
(
"
transform
"
,
self
.
coca_transform
is
None
)
,
]
coca_is_none
=
[
x
[
0
]
for
x
in
coca_is_none
if
x
[
1
]]
coca_is_none_str
=
'
,
'
.
join
(
coca_is_none
)
coca_is_none_str
=
"
,
"
.
join
(
coca_is_none
)
# generate prompts with coca model if prompt is None
if
content_prompt
is
None
:
if
len
(
coca_is_none
):
raise
ValueError
(
f
'
Content prompt is None and CoCa [
{
coca_is_none_str
}
] is None.
'
f
'
Set prompt or pass Coca [
{
coca_is_none_str
}
] to DiffusionPipeline.
'
f
"
Content prompt is None and CoCa [
{
coca_is_none_str
}
] is None.
"
f
"
Set prompt or pass Coca [
{
coca_is_none_str
}
] to DiffusionPipeline.
"
)
content_prompt
=
self
.
get_image_description
(
content_image
)
if
style_prompt
is
None
:
if
len
(
coca_is_none
):
raise
ValueError
(
f
'
Style prompt is None and CoCa [
{
coca_is_none_str
}
] is None.
'
f
'
Set prompt or pass Coca [
{
coca_is_none_str
}
] to DiffusionPipeline.
'
f
"
Style prompt is None and CoCa [
{
coca_is_none_str
}
] is None.
"
f
"
Set prompt or pass Coca [
{
coca_is_none_str
}
] to DiffusionPipeline.
"
)
style_prompt
=
self
.
get_image_description
(
style_image
)
# get prompt text embeddings for content and style
content_text_input
=
self
.
tokenizer
(
content_prompt
,
padding
=
'
max_length
'
,
padding
=
"
max_length
"
,
max_length
=
self
.
tokenizer
.
model_max_length
,
truncation
=
True
,
return_tensors
=
'
pt
'
,
return_tensors
=
"
pt
"
,
)
content_text_embeddings
=
self
.
text_encoder
(
content_text_input
.
input_ids
.
to
(
self
.
device
))[
0
]
content_text_embeddings
=
self
.
text_encoder
(
content_text_input
.
input_ids
.
to
(
self
.
device
))[
0
]
style_text_input
=
self
.
tokenizer
(
style_prompt
,
padding
=
'
max_length
'
,
padding
=
"
max_length
"
,
max_length
=
self
.
tokenizer
.
model_max_length
,
truncation
=
True
,
return_tensors
=
'
pt
'
,
return_tensors
=
"
pt
"
,
)
style_text_embeddings
=
self
.
text_encoder
(
style_text_input
.
input_ids
.
to
(
self
.
device
))[
0
]
style_text_embeddings
=
self
.
text_encoder
(
style_text_input
.
input_ids
.
to
(
self
.
device
))[
0
]
text_embeddings
=
slerp
(
slerp_prompt_style_strength
,
content_text_embeddings
,
style_text_embeddings
)
text_embeddings
=
slerp
(
slerp_prompt_style_strength
,
content_text_embeddings
,
style_text_embeddings
)
# duplicate text embeddings for each generation per prompt
text_embeddings
=
text_embeddings
.
repeat_interleave
(
batch_size
,
dim
=
0
)
# set timesteps
accepts_offset
=
'offset'
in
set
(
inspect
.
signature
(
self
.
scheduler
.
set_timesteps
).
parameters
.
keys
())
accepts_offset
=
"offset"
in
set
(
inspect
.
signature
(
self
.
scheduler
.
set_timesteps
).
parameters
.
keys
())
extra_set_kwargs
=
{}
if
accepts_offset
:
extra_set_kwargs
[
'
offset
'
]
=
1
extra_set_kwargs
[
"
offset
"
]
=
1
self
.
scheduler
.
set_timesteps
(
num_inference_steps
,
**
extra_set_kwargs
)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
self
.
scheduler
.
timesteps
.
to
(
self
.
device
)
timesteps
,
num_inference_steps
=
self
.
get_timesteps
(
num_inference_steps
,
noise_strength
,
self
.
device
)
timesteps
,
num_inference_steps
=
self
.
get_timesteps
(
num_inference_steps
,
noise_strength
,
self
.
device
)
latent_timestep
=
timesteps
[:
1
].
repeat
(
batch_size
)
# Preprocess image
preprocessed_content_image
=
preprocess
(
content_image
,
width
,
height
)
content_latents
=
self
.
prepare_latents
(
preprocessed_content_image
,
latent_timestep
,
batch_size
,
text_embeddings
.
dtype
,
self
.
device
,
generator
preprocessed_content_image
,
latent_timestep
,
batch_size
,
text_embeddings
.
dtype
,
self
.
device
,
generator
)
preprocessed_style_image
=
preprocess
(
style_image
,
width
,
height
)
style_latents
=
self
.
prepare_latents
(
preprocessed_style_image
,
latent_timestep
,
batch_size
,
text_embeddings
.
dtype
,
self
.
device
,
generator
preprocessed_style_image
,
latent_timestep
,
batch_size
,
text_embeddings
.
dtype
,
self
.
device
,
generator
)
latents
=
slerp
(
slerp_latent_style_strength
,
content_latents
,
style_latents
)
latents
=
slerp
(
slerp_latent_style_strength
,
content_latents
,
style_latents
)
if
clip_guidance_scale
>
0
:
content_clip_image_embedding
=
self
.
get_clip_image_embeddings
(
content_image
,
batch_size
)
style_clip_image_embedding
=
self
.
get_clip_image_embeddings
(
style_image
,
batch_size
)
content_clip_image_embedding
=
self
.
get_clip_image_embeddings
(
content_image
,
batch_size
)
style_clip_image_embedding
=
self
.
get_clip_image_embeddings
(
style_image
,
batch_size
)
clip_image_embeddings
=
slerp
(
slerp_clip_image_style_strength
,
content_clip_image_embedding
,
style_clip_image_embedding
)
slerp_clip_image_style_strength
,
content_clip_image_embedding
,
style_clip_image_embedding
)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
...
...
@@ -400,13 +361,10 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
# get unconditional embeddings for classifier free guidance
if
do_classifier_free_guidance
:
max_length
=
content_text_input
.
input_ids
.
shape
[
-
1
]
uncond_input
=
self
.
tokenizer
(
[
''
],
padding
=
'max_length'
,
max_length
=
max_length
,
return_tensors
=
'pt'
)
uncond_embeddings
=
self
.
text_encoder
(
uncond_input
.
input_ids
.
to
(
self
.
device
))[
0
]
uncond_input
=
self
.
tokenizer
([
""
],
padding
=
"max_length"
,
max_length
=
max_length
,
return_tensors
=
"pt"
)
uncond_embeddings
=
self
.
text_encoder
(
uncond_input
.
input_ids
.
to
(
self
.
device
))[
0
]
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings
=
uncond_embeddings
.
repeat_interleave
(
batch_size
,
dim
=
0
)
uncond_embeddings
=
uncond_embeddings
.
repeat_interleave
(
batch_size
,
dim
=
0
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
...
...
@@ -418,25 +376,19 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_shape
=
(
batch_size
,
self
.
unet
.
config
.
in_channels
,
height
//
8
,
width
//
8
)
latents_shape
=
(
batch_size
,
self
.
unet
.
config
.
in_channels
,
height
//
8
,
width
//
8
)
latents_dtype
=
text_embeddings
.
dtype
if
latents
is
None
:
if
self
.
device
.
type
==
'
mps
'
:
if
self
.
device
.
type
==
"
mps
"
:
# randn does not work reproducibly on mps
latents
=
torch
.
randn
(
latents_shape
,
generator
=
generator
,
device
=
'cpu'
,
dtype
=
latents_dtype
).
to
(
self
.
device
)
latents
=
torch
.
randn
(
latents_shape
,
generator
=
generator
,
device
=
"cpu"
,
dtype
=
latents_dtype
).
to
(
self
.
device
)
else
:
latents
=
torch
.
randn
(
latents_shape
,
generator
=
generator
,
device
=
self
.
device
,
dtype
=
latents_dtype
)
latents
=
torch
.
randn
(
latents_shape
,
generator
=
generator
,
device
=
self
.
device
,
dtype
=
latents_dtype
)
else
:
if
latents
.
shape
!=
latents_shape
:
raise
ValueError
(
f
'Unexpected latents shape, got
{
latents
.
shape
}
, expected
{
latents_shape
}
'
)
raise
ValueError
(
f
"Unexpected latents shape, got
{
latents
.
shape
}
, expected
{
latents_shape
}
"
)
latents
=
latents
.
to
(
self
.
device
)
# scale the initial noise by the standard deviation required by the scheduler
...
...
@@ -446,41 +398,34 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta
=
'eta'
in
set
(
inspect
.
signature
(
self
.
scheduler
.
step
).
parameters
.
keys
())
accepts_eta
=
"eta"
in
set
(
inspect
.
signature
(
self
.
scheduler
.
step
).
parameters
.
keys
())
extra_step_kwargs
=
{}
if
accepts_eta
:
extra_step_kwargs
[
'
eta
'
]
=
eta
extra_step_kwargs
[
"
eta
"
]
=
eta
# check if the scheduler accepts generator
accepts_generator
=
'generator'
in
set
(
inspect
.
signature
(
self
.
scheduler
.
step
).
parameters
.
keys
())
accepts_generator
=
"generator"
in
set
(
inspect
.
signature
(
self
.
scheduler
.
step
).
parameters
.
keys
())
if
accepts_generator
:
extra_step_kwargs
[
'
generator
'
]
=
generator
extra_step_kwargs
[
"
generator
"
]
=
generator
with
self
.
progress_bar
(
total
=
num_inference_steps
):
for
i
,
t
in
enumerate
(
timesteps
):
# expand the latents if we are doing classifier free guidance
latent_model_input
=
torch
.
cat
(
[
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latent_model_input
,
t
)
latent_model_input
=
torch
.
cat
([
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latent_model_input
,
t
)
# predict the noise residual
noise_pred
=
self
.
unet
(
latent_model_input
,
t
,
encoder_hidden_states
=
text_embeddings
).
sample
noise_pred
=
self
.
unet
(
latent_model_input
,
t
,
encoder_hidden_states
=
text_embeddings
).
sample
# perform classifier free guidance
if
do_classifier_free_guidance
:
noise_pred_uncond
,
noise_pred_text
=
noise_pred
.
chunk
(
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
\
(
noise_pred_text
-
noise_pred_uncond
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_pred_text
-
noise_pred_uncond
)
# perform clip guidance
if
clip_guidance_scale
>
0
:
text_embeddings_for_guidance
=
(
text_embeddings
.
chunk
(
2
)[
1
]
if
do_classifier_free_guidance
else
text_embeddings
text_embeddings
.
chunk
(
2
)[
1
]
if
do_classifier_free_guidance
else
text_embeddings
)
noise_pred
,
latents
=
self
.
cond_fn
(
latents
,
...
...
@@ -493,8 +438,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
)
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
latents
=
1
/
0.18215
*
latents
...
...
@@ -503,7 +447,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
'
pil
'
:
if
output_type
==
"
pil
"
:
image
=
self
.
numpy_to_pil
(
image
)
if
not
return_dict
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment