Commit cd77a036 (Unverified)
Authored Nov 09, 2022 by Suraj Patil; committed by GitHub on Nov 09, 2022
[CLIPGuidedStableDiffusion] support DDIM scheduler (#1190)
add ddim in clip guided
Parent: 663f0c19
Showing 1 changed file with 26 additions and 4 deletions.

examples/community/clip_guided_stable_diffusion.py  +26 −4
@@ -5,7 +5,14 @@ import torch
 from torch import nn
 from torch.nn import functional as F

-from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    DiffusionPipeline,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from torchvision import transforms
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
@@ -56,7 +63,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
         feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
@@ -123,7 +130,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # predict the noise residual
         noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

-        if isinstance(self.scheduler, PNDMScheduler):
+        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
             alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
             beta_prod_t = 1 - alpha_prod_t
             # compute predicted original sample from predicted noise also called
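The hunk above is cut off at the comment about the predicted original sample. DDIMScheduler can share this branch with PNDMScheduler because both expose alphas_cumprod, from which the predicted original sample is recovered from the noise prediction. A minimal sketch of that relationship; the helper and tensor names here are illustrative, not part of the diff:

import torch

def predicted_original_sample(sample: torch.Tensor, noise_pred: torch.Tensor, alpha_prod_t: torch.Tensor) -> torch.Tensor:
    # DDIM paper (https://arxiv.org/abs/2010.02502):
    # x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_theta(x_t)) / sqrt(alpha_bar_t)
    beta_prod_t = 1 - alpha_prod_t
    return (sample - beta_prod_t ** 0.5 * noise_pred) / alpha_prod_t ** 0.5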
@@ -176,6 +183,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
         clip_guidance_scale: Optional[float] = 100,
         clip_prompt: Optional[Union[str, List[str]]] = None,
         num_cutouts: Optional[int] = 4,
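The new eta argument controls the DDIM sampling noise: eta = 0.0 gives the deterministic DDIM sampler, while eta = 1.0 approximately recovers DDPM-like stochasticity. A minimal sketch of the variance scaling it applies, following the DDIM paper (https://arxiv.org/abs/2010.02502); the function name is illustrative, not part of the diff:

def ddim_sigma(eta: float, alpha_prod_t: float, alpha_prod_t_prev: float) -> float:
    # sigma_t(eta) = eta * sqrt((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)) * sqrt(1 - alpha_bar_t / alpha_bar_{t-1})
    variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * (1 - alpha_prod_t / alpha_prod_t_prev)
    return eta * variance ** 0.5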
@@ -275,6 +283,20 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma

+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
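The block added above forwards eta and generator only to schedulers whose step() signature accepts them, discovered via inspect.signature. A small self-contained sketch of the same pattern, using stand-in step functions rather than real schedulers:

import inspect

def step_ddim_like(noise_pred, t, latents, eta=0.0, generator=None):
    """Stand-in for a step() that accepts eta (DDIM-style)."""

def step_pndm_like(noise_pred, t, latents):
    """Stand-in for a step() that does not accept eta."""

def build_extra_step_kwargs(step_fn, eta, generator=None):
    # keep only the keyword arguments this particular step() understands
    params = set(inspect.signature(step_fn).parameters.keys())
    extra = {}
    if "eta" in params:
        extra["eta"] = eta
    if "generator" in params and generator is not None:
        extra["generator"] = generator
    return extra

print(build_extra_step_kwargs(step_ddim_like, eta=0.0))  # {'eta': 0.0}
print(build_extra_step_kwargs(step_pndm_like, eta=0.0))  # {}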
@@ -306,7 +328,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 )

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
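With DDIMScheduler now accepted by the constructor and eta exposed on __call__, the community pipeline can be driven by DDIM instead of PNDM or LMS. A minimal usage sketch, assuming the custom_pipeline loading path used by the community examples; the checkpoint IDs, prompt, and keyword values are illustrative and not taken from this commit:

from diffusers import DDIMScheduler, DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

clip_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"  # illustrative CLIP checkpoint
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_id)
clip_model = CLIPModel.from_pretrained(clip_id)

# DDIM configured with the usual Stable Diffusion beta schedule
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # illustrative Stable Diffusion checkpoint
    custom_pipeline="clip_guided_stable_diffusion",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    scheduler=scheduler,
).to("cuda")

out = pipe(
    prompt="a fantasy landscape, concept art",
    num_inference_steps=50,
    eta=0.0,                  # the new argument; only forwarded to schedulers that accept it
    clip_guidance_scale=100,
)
image = out.images[0]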