chenpangpang / diffusers · Commits

Unverified commit 21682bab, authored Aug 20, 2024 by Disty0, committed by GitHub on Aug 20, 2024.

Custom sampler support for Stable Cascade Decoder (#9132)
parent 214990e5

2 changed files, with 37 additions and 5 deletions:

  src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py        +35 -3
  src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py  +2 -2
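As a quick orientation before the diffs: the point of the commit is that the decoder pipeline no longer hard-codes DDPMWuerstchenScheduler behavior, so another diffusers sampler can be dropped in. A minimal sketch of that usage follows; the checkpoint id and the DDPMScheduler choice are illustrative assumptions, not something the commit itself prescribes.

    import torch
    from diffusers import StableCascadeDecoderPipeline, DDPMScheduler

    # Illustrative decoder checkpoint; any Stable Cascade decoder works the same way.
    pipe = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
    )

    # Swap in a custom sampler. Because DDPMScheduler exposes `betas`, the
    # pipeline can now derive the timestep-ratio conditioning added below;
    # schedulers without `betas` fall back to normalizing by the last timestep.
    pipe.scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)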
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
 
+    def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
+        s = torch.tensor([0.008])
+        clamp_range = [0, 1]
+        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
+        var = alphas_cumprod[t]
+        var = var.clamp(*clamp_range)
+        s, min_var = s.to(var.device), min_var.to(var.device)
+        ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+        return ratio
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
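What the new helper computes: it inverts the squared-cosine noise schedule. With alpha_bar(r) = cos((r + s)/(1 + s) * pi/2)^2 / cos(s/(1 + s) * pi/2)^2, solving for r yields exactly the expression above, so a discrete timestep's alphas_cumprod entry maps back to the continuous ratio in [0, 1] that Stable Cascade was trained on. A standalone round-trip check (the function body is copied from the diff; exposing s as a parameter is for illustration only):

    import math
    import torch

    def get_timestep_ratio_conditioning(t, alphas_cumprod, s=0.008):
        # Copied from the commit; inverts the squared-cosine schedule.
        s = torch.tensor([s])
        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
        var = alphas_cumprod[t].clamp(0, 1)
        return (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s

    # Build a squared-cosine schedule over 1000 steps and invert it.
    s = 0.008
    r = torch.linspace(0, 1, 1000)
    alphas_cumprod = (
        torch.cos((r + s) / (1 + s) * torch.pi * 0.5) ** 2
        / math.cos(s / (1 + s) * math.pi * 0.5) ** 2
    )
    print(get_timestep_ratio_conditioning(500, alphas_cumprod))  # ~0.5005 == r[500]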
@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
             batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
         )
 
+        if isinstance(self.scheduler, DDPMWuerstchenScheduler):
+            timesteps = timesteps[:-1]
+        else:
+            if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
+                self.scheduler.config.clip_sample = False  # disample sample clipping
+                logger.warning(" set `clip_sample` to be False")
+
         # 6. Run denoising loop
-        self._num_timesteps = len(timesteps[:-1])
-        for i, t in enumerate(self.progress_bar(timesteps[:-1])):
-            timestep_ratio = t.expand(latents.size(0)).to(dtype)
+        if hasattr(self.scheduler, "betas"):
+            alphas = 1.0 - self.scheduler.betas
+            alphas_cumprod = torch.cumprod(alphas, dim=0)
+        else:
+            alphas_cumprod = []
+
+        self._num_timesteps = len(timesteps)
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                if len(alphas_cumprod) > 0:
+                    timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
+                    timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
+                else:
+                    timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
+            else:
+                timestep_ratio = t.expand(latents.size(0)).to(dtype)
 
             # 7. Denoise latents
             predicted_latents = self.decoder(
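In prose, the new loop does the following for samplers other than DDPMWuerstchenScheduler: if the scheduler exposes `betas`, rebuild alphas_cumprod and convert each discrete timestep into the trained ratio; otherwise fall back to dividing the raw timestep by the scheduler's last one. A sketch of that dispatch in isolation, using DDPMScheduler as an illustrative stand-in and the standalone helper from the sketch above:

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2")
    scheduler.set_timesteps(20)

    if hasattr(scheduler, "betas"):
        # Exact trained ratio, as the new decoder loop computes it.
        alphas_cumprod = torch.cumprod(1.0 - scheduler.betas, dim=0)
        ratios = [get_timestep_ratio_conditioning(t.long(), alphas_cumprod)
                  for t in scheduler.timesteps]
    else:
        # Mirrors the commit's fallback for schedulers without `betas`.
        ratios = [t.float() / scheduler.timesteps[-1] for t in scheduler.timesteps]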
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
                 predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
 
             # 9. Renoise latents to next timestep
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                timestep_ratio = t
             latents = self.scheduler.step(
                 model_output=predicted_latents,
                 timestep=timestep_ratio,
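One detail worth flagging in the hunk above: for samplers other than DDPMWuerstchenScheduler, the raw discrete t is restored just before scheduler.step(). The conditioned ratio is only how the decoder model sees time; a standard diffusers scheduler still expects its own original timestep when stepping, which is presumably why the commit swaps it back here.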
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
         return self._num_timesteps
 
     def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
-        s = torch.tensor([0.003])
+        s = torch.tensor([0.008])
         clamp_range = [0, 1]
         min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
         var = alphas_cumprod[t]
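The one-character diff above is a constant fix: the prior pipeline's helper used s = 0.003, while the decoder's new copy (and the widely used squared-cosine schedule from Nichol and Dhariwal's improved DDPM) uses s = 0.008. After this change both pipelines invert the same schedule.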
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
         if isinstance(self.scheduler, DDPMWuerstchenScheduler):
             timesteps = timesteps[:-1]
         else:
-            if self.scheduler.config.clip_sample:
+            if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
                 self.scheduler.config.clip_sample = False  # disample sample clipping
                 logger.warning(" set `clip_sample` to be False")
 
         # 6. Run denoising loop
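The companion fix here guards the clip_sample check: not every scheduler config defines clip_sample, so the unguarded self.scheduler.config.clip_sample would raise an AttributeError for such samplers. Wrapping it in hasattr, matching the new decoder code, keeps the branch safe for arbitrary schedulers.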