renzhc / diffusers_dcu · Commits · 3383f774

Unverified commit 3383f774, authored Oct 06, 2022 by Suraj Patil, committed by GitHub on Oct 06, 2022.

update the clip guided PR according to the new API (#751)

parent df9c0701

Showing 1 changed file with 27 additions and 23 deletions:
examples/community/clip_guided_stable_diffusion.py (+27, -23)
@@ -175,6 +175,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
         clip_guidance_scale: Optional[float] = 100,
         clip_prompt: Optional[Union[str, List[str]]] = None,
         num_cutouts: Optional[int] = 4,
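The `num_images_per_prompt` argument added in this hunk is threaded through the rest of the diff. A minimal usage sketch of the community pipeline follows; the checkpoint ids, the `custom_pipeline` loading path, and the generation settings are illustrative assumptions, not part of this commit:

import torch
from diffusers import DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

# Assumed checkpoint ids for illustration; any CLIP model + SD checkpoint should work.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="clip_guided_stable_diffusion",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
)

out = pipe(
    "a fantasy landscape, trending on artstation",
    num_images_per_prompt=2,  # the argument this commit adds
    clip_guidance_scale=100,
    num_inference_steps=50,
    generator=torch.Generator().manual_seed(0),
)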
@@ -203,6 +204,8 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             return_tensors="pt",
         )
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

         if clip_guidance_scale > 0:
             if clip_prompt is not None:
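`repeat_interleave` duplicates each prompt's embedding block-wise (rows ordered `[p0, p0, p1, p1]` rather than tiled `[p0, p1, p0, p1]`), which keeps every prompt aligned with its slice of the enlarged latents batch built further down. A standalone sketch of the tensor semantics:

import torch

text_embeddings = torch.arange(4.0).reshape(2, 2)  # 2 prompts, embedding dim 2
num_images_per_prompt = 2
repeated = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
print(repeated.shape)  # torch.Size([4, 2]), rows ordered p0, p0, p1, p1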
@@ -217,6 +220,8 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 clip_text_input = text_input.input_ids.to(self.device)
             text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
             text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+            # duplicate text embeddings clip for each generation per prompt
+            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
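Dividing the CLIP text features by their L2 norm projects them onto the unit hypersphere, so the guidance loss computed later in the file depends only on the angle between image and text features. A small standalone sketch (the 512 feature dimension is illustrative):

import torch

feats = torch.randn(2, 512)
feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
print(feats.norm(dim=-1))  # tensor([1., 1.]): unit-length feature vectors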
@@ -225,10 +230,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
             max_length = text_input.input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
-            )
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
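Both halves of the classifier-free-guidance batch are now duplicated per prompt before the concatenation the trailing comments describe, so the single UNet forward pass sees matching conditional and unconditional rows. A shape-only sketch, assuming the usual CLIP sequence length 77 and hidden size 768:

import torch

batch = 1 * 2  # batch_size * num_images_per_prompt
uncond_embeddings = torch.randn(batch, 77, 768)
cond_embeddings = torch.randn(batch, 77, 768)
text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
print(text_embeddings.shape)  # torch.Size([4, 77, 768])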
@@ -240,14 +245,16 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
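The rewritten branch widens the latents batch by `num_images_per_prompt` and draws seeded noise on CPU only for `mps` (where `torch.randn` with a generator was unavailable at the time), directly on-device otherwise. A standalone sketch of the shape arithmetic, assuming 512x512 output and the UNet's usual 4 input channels:

import torch

batch_size, num_images_per_prompt = 1, 2
height = width = 512
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
latents = torch.randn(latents_shape, generator=torch.Generator().manual_seed(0))
print(latents.shape)  # torch.Size([2, 4, 64, 64])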
@@ -261,17 +268,17 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma

-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
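The scheduler-specific special cases give way to the new scheduler-agnostic API: `init_noise_sigma` replaces the LMS-only `sigmas[0]` multiplication (it is simply 1.0 for schedulers that need no scaling) and `scale_model_input` absorbs the per-step K-LMS division. A quick comparison using default scheduler configs (values are version-dependent):

from diffusers import DDIMScheduler, LMSDiscreteScheduler

ddim = DDIMScheduler()
lms = LMSDiscreteScheduler()
lms.set_timesteps(50)
print(ddim.init_noise_sigma)  # 1.0 -> multiplying the latents is a no-op
print(lms.init_noise_sigma)   # the largest sigma, i.e. what sigmas[0] used to provide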
@@ -299,9 +306,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             )

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

         # scale and decode the image latents with vae
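With the unified interface, every scheduler's `step()` takes the timestep value `t` rather than the loop index `i`, which removes the `LMSDiscreteScheduler` branch. A minimal runnable sketch, with a default-config DDIM scheduler standing in for the pipeline's scheduler:

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(50)
latents = torch.randn(1, 4, 64, 64)
noise_pred = torch.randn_like(latents)
t = scheduler.timesteps[0]
# same call shape for LMS, PNDM, DDIM, ...; .prev_sample is x_{t-1}
latents = scheduler.step(noise_pred, t, latents).prev_sample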