Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
21682bab
Unverified
Commit
21682bab
authored
Aug 20, 2024
by
Disty0
Committed by
GitHub
Aug 20, 2024
Browse files
Custom sampler support for Stable Cascade Decoder (#9132)
Custom sampler support Stable Cascade Decoder
parent
214990e5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
5 deletions
+37
-5
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
...users/pipelines/stable_cascade/pipeline_stable_cascade.py
+35
-3
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
...pipelines/stable_cascade/pipeline_stable_cascade_prior.py
+2
-2
No files found.
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
View file @
21682bab
...
@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
...
@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
def
num_timesteps
(
self
):
def
num_timesteps
(
self
):
return
self
.
_num_timesteps
return
self
.
_num_timesteps
def
get_timestep_ratio_conditioning
(
self
,
t
,
alphas_cumprod
):
s
=
torch
.
tensor
([
0.008
])
clamp_range
=
[
0
,
1
]
min_var
=
torch
.
cos
(
s
/
(
1
+
s
)
*
torch
.
pi
*
0.5
)
**
2
var
=
alphas_cumprod
[
t
]
var
=
var
.
clamp
(
*
clamp_range
)
s
,
min_var
=
s
.
to
(
var
.
device
),
min_var
.
to
(
var
.
device
)
ratio
=
(((
var
*
min_var
)
**
0.5
).
acos
()
/
(
torch
.
pi
*
0.5
))
*
(
1
+
s
)
-
s
return
ratio
@
torch
.
no_grad
()
@
torch
.
no_grad
()
@
replace_example_docstring
(
EXAMPLE_DOC_STRING
)
@
replace_example_docstring
(
EXAMPLE_DOC_STRING
)
def
__call__
(
def
__call__
(
...
@@ -434,9 +444,29 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
...
@@ -434,9 +444,29 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
batch_size
,
image_embeddings
,
num_images_per_prompt
,
dtype
,
device
,
generator
,
latents
,
self
.
scheduler
batch_size
,
image_embeddings
,
num_images_per_prompt
,
dtype
,
device
,
generator
,
latents
,
self
.
scheduler
)
)
if
isinstance
(
self
.
scheduler
,
DDPMWuerstchenScheduler
):
timesteps
=
timesteps
[:
-
1
]
else
:
if
hasattr
(
self
.
scheduler
.
config
,
"clip_sample"
)
and
self
.
scheduler
.
config
.
clip_sample
:
self
.
scheduler
.
config
.
clip_sample
=
False
# disample sample clipping
logger
.
warning
(
" set `clip_sample` to be False"
)
# 6. Run denoising loop
# 6. Run denoising loop
self
.
_num_timesteps
=
len
(
timesteps
[:
-
1
])
if
hasattr
(
self
.
scheduler
,
"betas"
):
for
i
,
t
in
enumerate
(
self
.
progress_bar
(
timesteps
[:
-
1
])):
alphas
=
1.0
-
self
.
scheduler
.
betas
alphas_cumprod
=
torch
.
cumprod
(
alphas
,
dim
=
0
)
else
:
alphas_cumprod
=
[]
self
.
_num_timesteps
=
len
(
timesteps
)
for
i
,
t
in
enumerate
(
self
.
progress_bar
(
timesteps
)):
if
not
isinstance
(
self
.
scheduler
,
DDPMWuerstchenScheduler
):
if
len
(
alphas_cumprod
)
>
0
:
timestep_ratio
=
self
.
get_timestep_ratio_conditioning
(
t
.
long
().
cpu
(),
alphas_cumprod
)
timestep_ratio
=
timestep_ratio
.
expand
(
latents
.
size
(
0
)).
to
(
dtype
).
to
(
device
)
else
:
timestep_ratio
=
t
.
float
().
div
(
self
.
scheduler
.
timesteps
[
-
1
]).
expand
(
latents
.
size
(
0
)).
to
(
dtype
)
else
:
timestep_ratio
=
t
.
expand
(
latents
.
size
(
0
)).
to
(
dtype
)
timestep_ratio
=
t
.
expand
(
latents
.
size
(
0
)).
to
(
dtype
)
# 7. Denoise latents
# 7. Denoise latents
...
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
...
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
predicted_latents
=
torch
.
lerp
(
predicted_latents_uncond
,
predicted_latents_text
,
self
.
guidance_scale
)
predicted_latents
=
torch
.
lerp
(
predicted_latents_uncond
,
predicted_latents_text
,
self
.
guidance_scale
)
# 9. Renoise latents to next timestep
# 9. Renoise latents to next timestep
if
not
isinstance
(
self
.
scheduler
,
DDPMWuerstchenScheduler
):
timestep_ratio
=
t
latents
=
self
.
scheduler
.
step
(
latents
=
self
.
scheduler
.
step
(
model_output
=
predicted_latents
,
model_output
=
predicted_latents
,
timestep
=
timestep_ratio
,
timestep
=
timestep_ratio
,
...
...
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
View file @
21682bab
...
@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
...
@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
return
self
.
_num_timesteps
return
self
.
_num_timesteps
def
get_timestep_ratio_conditioning
(
self
,
t
,
alphas_cumprod
):
def
get_timestep_ratio_conditioning
(
self
,
t
,
alphas_cumprod
):
s
=
torch
.
tensor
([
0.00
3
])
s
=
torch
.
tensor
([
0.00
8
])
clamp_range
=
[
0
,
1
]
clamp_range
=
[
0
,
1
]
min_var
=
torch
.
cos
(
s
/
(
1
+
s
)
*
torch
.
pi
*
0.5
)
**
2
min_var
=
torch
.
cos
(
s
/
(
1
+
s
)
*
torch
.
pi
*
0.5
)
**
2
var
=
alphas_cumprod
[
t
]
var
=
alphas_cumprod
[
t
]
...
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
...
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
if
isinstance
(
self
.
scheduler
,
DDPMWuerstchenScheduler
):
if
isinstance
(
self
.
scheduler
,
DDPMWuerstchenScheduler
):
timesteps
=
timesteps
[:
-
1
]
timesteps
=
timesteps
[:
-
1
]
else
:
else
:
if
self
.
scheduler
.
config
.
clip_sample
:
if
hasattr
(
self
.
scheduler
.
config
,
"clip_sample"
)
and
self
.
scheduler
.
config
.
clip_sample
:
self
.
scheduler
.
config
.
clip_sample
=
False
# disample sample clipping
self
.
scheduler
.
config
.
clip_sample
=
False
# disample sample clipping
logger
.
warning
(
" set `clip_sample` to be False"
)
logger
.
warning
(
" set `clip_sample` to be False"
)
# 6. Run denoising loop
# 6. Run denoising loop
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment