renzhc / diffusers_dcu / Commits / 160c377d

Commit 160c377d, authored May 30, 2023 by Patrick von Platen

Make style

parent bb22d546

Showing 1 changed file with 70 additions and 126 deletions

examples/community/clip_guided_images_mixing_stable_diffusion.py  (+70, -126)
@@ -32,8 +32,7 @@ def preprocess(image, w, h):
         image = [image]

     if isinstance(image[0], PIL.Image.Image):
-        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION['lanczos']))[None, :] for i in image]
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
         image = np.concatenate(image, axis=0)
         image = np.array(image).astype(np.float32) / 255.0
         image = image.transpose(0, 3, 1, 2)
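For orientation, a minimal standalone sketch of the preprocessing convention this hunk touches (hypothetical helper name, assuming an RGB input; the mapping to [-1, 1] happens a few lines below the hunk):

# Minimal sketch of the preprocessing convention above (hypothetical helper,
# not part of this commit): PIL image -> float32 NCHW tensor.
import numpy as np
import PIL.Image
import torch


def preprocess_sketch(image: PIL.Image.Image, w: int, h: int) -> torch.Tensor:
    arr = np.array(image.resize((w, h), resample=PIL.Image.LANCZOS))[None, :]  # (1, H, W, C)
    arr = arr.astype(np.float32) / 255.0      # scale to [0, 1]
    arr = arr.transpose(0, 3, 1, 2)           # NHWC -> NCHW
    return torch.from_numpy(2.0 * arr - 1.0)  # map to [-1, 1], the range the VAE expects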
@@ -45,7 +44,6 @@ def preprocess(image, w, h):
 def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
     if not isinstance(v0, np.ndarray):
         inputs_are_torch = True
         input_device = v0.device
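The hunk only shows the top of `slerp`; for reference, a hedged sketch of the standard spherical linear interpolation used by these pipelines, including the torch-to-numpy round trip hinted at by `inputs_are_torch` (not necessarily line-for-line the file's body):

import numpy as np
import torch


def slerp_sketch(t, v0, v1, DOT_THRESHOLD=0.9995):
    # Accept torch tensors, compute in numpy, return on the original device.
    inputs_are_torch = not isinstance(v0, np.ndarray)
    if inputs_are_torch:
        input_device = v0.device
        v0, v1 = v0.cpu().numpy(), v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # vectors are nearly parallel: fall back to linear interpolation
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        theta_t = theta_0 * t
        v2 = (np.sin(theta_0 - theta_t) * v0 + np.sin(theta_t) * v1) / np.sin(theta_0)

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)
    return v2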
@@ -82,7 +80,6 @@ def set_requires_grad(model, value):
 class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -112,15 +109,14 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
         self.feature_extractor_size = (
             feature_extractor.size
             if isinstance(feature_extractor.size, int)
-            else feature_extractor.size['shortest_edge']
+            else feature_extractor.size["shortest_edge"]
         )
         self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
         set_requires_grad(self.text_encoder, False)
         set_requires_grad(self.clip_model, False)

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = 'auto'):
-        if slice_size == 'auto':
+    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+        if slice_size == "auto":
             # half the attention head size is usually a good trade-off between
             # speed and memory
             slice_size = self.unet.config.attention_head_dim // 2
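As the comment notes, `slice_size = attention_head_dim // 2` trades a little speed for lower peak memory. A self-contained sketch of what slicing the attention computation means (toy shapes and a hypothetical helper, not the diffusers implementation):

# Sketch: compute softmax(QK^T)V in chunks along the batch*heads axis so the full
# attention matrix for all heads is never materialized at once.
import torch


def sliced_attention_sketch(q, k, v, slice_size):
    # q, k, v: (batch_heads, seq, dim); slices are independent along dim 0.
    out = torch.empty_like(q)
    for start in range(0, q.shape[0], slice_size):
        end = start + slice_size
        attn = torch.softmax(q[start:end] @ k[start:end].transpose(1, 2) / q.shape[-1] ** 0.5, dim=-1)
        out[start:end] = attn @ v[start:end]
    return out


q = k = v = torch.randn(8, 77, 40)          # 8 batch*head slices
full = sliced_attention_sketch(q, k, v, 8)  # one slice == ordinary attention
half = sliced_attention_sketch(q, k, v, 4)  # "auto"-style: half the heads per slice
assert torch.allclose(full, half, atol=1e-5)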
@@ -143,8 +139,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start:]
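`get_timesteps` is the usual img2img truncation: `strength` decides how many of the final scheduler steps are actually run. A worked example with assumed values:

# Worked example of the truncation above: 50 steps, strength 0.8.
num_inference_steps, strength = 50, 0.8
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
# timesteps = scheduler.timesteps[t_start:] -> the 10 noisiest steps are skipped,
# so the input image is only partially destroyed before denoising begins.
print(init_timestep, t_start)  # 40 10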
@@ -153,15 +148,13 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
     def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None):
         if not isinstance(image, torch.Tensor):
-            raise ValueError(f'`image` has to be of type `torch.Tensor` but is {type(image)}')
+            raise ValueError(f"`image` has to be of type `torch.Tensor` but is {type(image)}")

         image = image.to(device=device, dtype=dtype)

         if isinstance(generator, list):
             init_latents = [
                 self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
@@ -171,8 +164,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
         init_latents = 0.18215 * init_latents
         init_latents = init_latents.repeat_interleave(batch_size, dim=0)

         noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)

         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
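`prepare_latents` follows the standard image-to-latents recipe: VAE-encode, scale by the hard-coded 0.18215 factor, repeat across the batch, then let the scheduler noise the latents to the chosen start timestep. A sketch of what that noising step does for DDPM/DDIM-style schedules (toy schedule and assumed shapes):

# Sketch of scheduler.add_noise for DDPM/DDIM-style schedules:
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
import torch

alphas_cumprod = torch.linspace(0.9999, 0.0047, 1000)  # assumed toy schedule
x0 = torch.randn(1, 4, 64, 64)                          # clean VAE latents
eps = torch.randn_like(x0)
t = 700
x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * eps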
@@ -183,21 +175,16 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
     def get_image_description(self, image):
         transformed_image = self.coca_transform(image).unsqueeze(0)
         with torch.no_grad(), torch.cuda.amp.autocast():
             generated = self.coca_model.generate(transformed_image.to(device=self.device, dtype=self.coca_model.dtype))
         generated = self.coca_tokenizer.decode(generated[0].cpu().numpy())
-        return generated.split('<end_of_text>')[0].replace('<start_of_text>', '').rstrip('.,')
+        return generated.split("<end_of_text>")[0].replace("<start_of_text>", "").rstrip(".,")

     def get_clip_image_embeddings(self, image, batch_size):
         clip_image_input = self.feature_extractor.preprocess(image)
-        clip_image_features = torch.from_numpy(clip_image_input['pixel_values'][0]).unsqueeze(0).to(self.device).half()
-        image_embeddings_clip = self.clip_model.get_image_features(clip_image_features)
-        image_embeddings_clip = image_embeddings_clip / \
-            image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-        image_embeddings_clip = image_embeddings_clip.repeat_interleave(batch_size, dim=0)
+        clip_image_features = torch.from_numpy(clip_image_input["pixel_values"][0]).unsqueeze(0).to(self.device).half()
+        image_embeddings_clip = self.clip_model.get_image_features(clip_image_features)
+        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+        image_embeddings_clip = image_embeddings_clip.repeat_interleave(batch_size, dim=0)
         return image_embeddings_clip

     @torch.enable_grad()
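The CLIP image features are L2-normalized so that the guidance loss later operates on the unit hypersphere; a minimal sketch (768 is an assumed feature size):

# Sketch: unit-normalize embeddings so distances live on the hypersphere.
import torch

emb = torch.randn(1, 768)                        # assumed CLIP feature size
emb = emb / emb.norm(p=2, dim=-1, keepdim=True)  # same normalization as above
emb = emb.repeat_interleave(4, dim=0)            # one copy per batch element
assert torch.allclose(emb.norm(dim=-1), torch.ones(4))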
@@ -213,20 +200,17 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
     ):
         latents = latents.detach().requires_grad_()

         latent_model_input = self.scheduler.scale_model_input(latents, timestep)

         # predict the noise residual
         noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

         if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
             alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
             beta_prod_t = 1 - alpha_prod_t
             # compute predicted original sample from predicted noise also called
             # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
             pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

             fac = torch.sqrt(beta_prod_t)
             sample = pred_original_sample * (fac) + latents * (1 - fac)
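The `pred_original_sample` line is DDIM equation (12) solved for x_0: x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t). A quick numerical check of that inversion:

# Numerical check of the x_0 reconstruction used above.
import torch

alpha_prod_t = torch.tensor(0.6)
beta_prod_t = 1 - alpha_prod_t
x0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0)
x_t = alpha_prod_t.sqrt() * x0 + beta_prod_t.sqrt() * eps          # forward noising
pred_x0 = (x_t - beta_prod_t ** 0.5 * eps) / alpha_prod_t ** 0.5   # inversion, as in the diff
assert torch.allclose(pred_x0, x0, atol=1e-5)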
@@ -234,8 +218,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
             sigma = self.scheduler.sigmas[index]
             sample = latents - sigma * noise_pred
         else:
-            raise ValueError(f'scheduler type {type(self.scheduler)} not supported')
+            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

         # Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
         sample = 1 / 0.18215 * sample
@@ -246,11 +229,9 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
         image = self.normalize(image).to(latents.dtype)

         image_embeddings_clip = self.clip_model.get_image_features(image)
-        image_embeddings_clip = image_embeddings_clip / \
-            image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

         loss = spherical_dist_loss(image_embeddings_clip, original_image_embeddings_clip).mean() * clip_guidance_scale

         grads = -torch.autograd.grad(loss, latents)[0]
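`spherical_dist_loss` is defined near the top of the file, outside this diff; its negative gradient with respect to the latents is what nudges sampling toward the mixed CLIP target. A hedged sketch of a spherical distance of this kind between unit-normalized embeddings (the exact reduction dims in the file may differ):

# Sketch: squared great-circle distance between unit-normalized embeddings,
# d(x, y) = 2 * arcsin(||x_hat - y_hat|| / 2) ** 2
import torch
import torch.nn.functional as F


def spherical_dist_sketch(x, y):
    x, y = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


a, b = torch.randn(4, 768), torch.randn(4, 768)
print(spherical_dist_sketch(a, b).mean())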
@@ -277,121 +258,101 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
         eta: float = 0.0,
         clip_guidance_scale: Optional[float] = 100,
         generator: Optional[torch.Generator] = None,
-        output_type: Optional[str] = 'pil',
+        output_type: Optional[str] = "pil",
         return_dict: bool = True,
         slerp_latent_style_strength: float = 0.8,
         slerp_prompt_style_strength: float = 0.1,
         slerp_clip_image_style_strength: float = 0.1,
     ):
         if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(f'You have passed {batch_size} batch_size, but only {len(generator)} generators.')
+            raise ValueError(f"You have passed {batch_size} batch_size, but only {len(generator)} generators.")

         if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f'`height` and `width` have to be divisible by 8 but are {height} and {width}.')
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

         if isinstance(generator, torch.Generator) and batch_size > 1:
             generator = [generator] + [None] * (batch_size - 1)

         coca_is_none = [
-            ('model', self.coca_model is None),
-            ('tokenizer', self.coca_tokenizer is None),
-            ('transform', self.coca_transform is None)
+            ("model", self.coca_model is None),
+            ("tokenizer", self.coca_tokenizer is None),
+            ("transform", self.coca_transform is None),
         ]
         coca_is_none = [x[0] for x in coca_is_none if x[1]]
-        coca_is_none_str = ', '.join(coca_is_none)
+        coca_is_none_str = ", ".join(coca_is_none)
         # generate prompts with coca model if prompt is None
         if content_prompt is None:
             if len(coca_is_none):
                 raise ValueError(
-                    f'Content prompt is None and CoCa [{coca_is_none_str}] is None.'
-                    f'Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline.'
+                    f"Content prompt is None and CoCa [{coca_is_none_str}] is None."
+                    f"Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline."
                 )
             content_prompt = self.get_image_description(content_image)
         if style_prompt is None:
             if len(coca_is_none):
                 raise ValueError(
-                    f'Style prompt is None and CoCa [{coca_is_none_str}] is None.'
-                    f'Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline.'
+                    f"Style prompt is None and CoCa [{coca_is_none_str}] is None."
+                    f"Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline."
                 )
             style_prompt = self.get_image_description(style_image)

         # get prompt text embeddings for content and style
         content_text_input = self.tokenizer(
             content_prompt,
-            padding='max_length',
+            padding="max_length",
             max_length=self.tokenizer.model_max_length,
             truncation=True,
-            return_tensors='pt',
+            return_tensors="pt",
         )
         content_text_embeddings = self.text_encoder(content_text_input.input_ids.to(self.device))[0]

         style_text_input = self.tokenizer(
             style_prompt,
-            padding='max_length',
+            padding="max_length",
             max_length=self.tokenizer.model_max_length,
             truncation=True,
-            return_tensors='pt',
+            return_tensors="pt",
         )
         style_text_embeddings = self.text_encoder(style_text_input.input_ids.to(self.device))[0]

         text_embeddings = slerp(slerp_prompt_style_strength, content_text_embeddings, style_text_embeddings)

         # duplicate text embeddings for each generation per prompt
         text_embeddings = text_embeddings.repeat_interleave(batch_size, dim=0)

         # set timesteps
-        accepts_offset = 'offset' in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
         extra_set_kwargs = {}
         if accepts_offset:
-            extra_set_kwargs['offset'] = 1
+            extra_set_kwargs["offset"] = 1

         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
         # Some schedulers like PNDM have timesteps as arrays
         # It's more optimized to move all timesteps to correct device beforehand
         self.scheduler.timesteps.to(self.device)

         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, noise_strength, self.device)
         latent_timestep = timesteps[:1].repeat(batch_size)

         # Preprocess image
         preprocessed_content_image = preprocess(content_image, width, height)
         content_latents = self.prepare_latents(
             preprocessed_content_image, latent_timestep, batch_size, text_embeddings.dtype, self.device, generator
         )

         preprocessed_style_image = preprocess(style_image, width, height)
         style_latents = self.prepare_latents(
             preprocessed_style_image, latent_timestep, batch_size, text_embeddings.dtype, self.device, generator
         )

         latents = slerp(slerp_latent_style_strength, content_latents, style_latents)

         if clip_guidance_scale > 0:
             content_clip_image_embedding = self.get_clip_image_embeddings(content_image, batch_size)
             style_clip_image_embedding = self.get_clip_image_embeddings(style_image, batch_size)
             clip_image_embeddings = slerp(
                 slerp_clip_image_style_strength, content_clip_image_embedding, style_clip_image_embedding
             )

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
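Content and style are mixed at three separate points, each with its own slerp weight: the initial latents (`slerp_latent_style_strength`), the prompt embeddings (`slerp_prompt_style_strength`), and the CLIP image embeddings used for guidance (`slerp_clip_image_style_strength`). A toy 2D example of why slerp rather than lerp is used for embeddings (assumed vectors):

# Toy illustration: slerp keeps the interpolant on the same sphere as the
# endpoints, while lerp shrinks it toward the origin.
import numpy as np

v0 = np.array([1.0, 0.0])
v1 = np.array([0.0, 1.0])
t = 0.5
lerp = (1 - t) * v0 + t * v1
theta = np.arccos(np.clip(np.dot(v0, v1), -1.0, 1.0))
slerp_mix = (np.sin((1 - t) * theta) * v0 + np.sin(t * theta) * v1) / np.sin(theta)
print(np.linalg.norm(lerp), np.linalg.norm(slerp_mix))  # ~0.707 vs 1.0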
@@ -400,13 +361,10 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
             max_length = content_text_input.input_ids.shape[-1]
-            uncond_input = self.tokenizer([''], padding='max_length', max_length=max_length, return_tensors='pt')
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

             # duplicate unconditional embeddings for each generation per prompt
             uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size, dim=0)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -418,25 +376,19 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
         latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            if self.device.type == 'mps':
+            if self.device.type == "mps":
                 # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device='cpu', dtype=latents_dtype).to(self.device)
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(self.device)
             else:
                 latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
-                raise ValueError(f'Unexpected latents shape, got {latents.shape}, expected {latents_shape}')
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
             latents = latents.to(self.device)

         # scale the initial noise by the standard deviation required by the scheduler
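The trailing comment refers to a line outside this hunk; for sigma-based schedulers the usual scaling is by `scheduler.init_noise_sigma`. A hedged sketch under that assumption:

# Sketch of the scaling the comment above refers to (assumed, not part of this hunk):
# K-diffusion-style schedulers expect initial noise with std init_noise_sigma.
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
print(scheduler.init_noise_sigma)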
@@ -446,41 +398,34 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
-        accepts_eta = 'eta' in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
         extra_step_kwargs = {}
         if accepts_eta:
-            extra_step_kwargs['eta'] = eta
+            extra_step_kwargs["eta"] = eta

         # check if the scheduler accepts generator
-        accepts_generator = 'generator' in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
         if accepts_generator:
-            extra_step_kwargs['generator'] = generator
+            extra_step_kwargs["generator"] = generator

         with self.progress_bar(total=num_inference_steps):
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                 # perform classifier free guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * \
-                        (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # perform clip guidance
                 if clip_guidance_scale > 0:
                     text_embeddings_for_guidance = (
                         text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
                     )
                     noise_pred, latents = self.cond_fn(
                         latents,
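The loop applies standard classifier-free guidance: the batch is doubled (unconditional plus conditional), the UNet output is split with `chunk(2)`, and the two predictions are recombined with the guidance weight, as in equation (2) of the Imagen paper cited above. A minimal numeric sketch:

# Minimal sketch of the classifier-free guidance combination used in the loop.
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 64, 64)  # [uncond; text] stacked batch
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert guided.shape == (1, 4, 64, 64)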
@@ -493,8 +438,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
                     )

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

         # Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
         latents = 1 / 0.18215 * latents
@@ -503,7 +447,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()

-        if output_type == 'pil':
+        if output_type == "pil":
             image = self.numpy_to_pil(image)

         if not return_dict:
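The decoded VAE output lives in [-1, 1]; the tail of the pipeline maps it to [0, 1], reorders to NHWC, and converts to PIL when `output_type == "pil"`. A sketch of that conversion, mirroring what `numpy_to_pil` does (a random tensor stands in for the decoder output):

# Sketch of the final tensor -> PIL conversion.
import numpy as np
import PIL.Image
import torch

decoded = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for VAE output in [-1, 1]
image = (decoded / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
pil_images = [PIL.Image.fromarray((img * 255).round().astype("uint8")) for img in image]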