Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
ab1b7b20
Unverified
Commit
ab1b7b20
authored
Oct 23, 2024
by
Álvaro Somoza
Committed by
GitHub
Oct 23, 2024
Browse files
[Official callbacks] SDXL Controlnet CFG Cutoff (#9311)
* initial proposal * style
parent
9366c8f8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
3 deletions
+58
-3
src/diffusers/callbacks.py
src/diffusers/callbacks.py
+56
-3
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
...ffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+2
-0
No files found.
src/diffusers/callbacks.py
View file @
ab1b7b20
...
...
@@ -97,13 +97,17 @@ class SDCFGCutoffCallback(PipelineCallback):
class
SDXLCFGCutoffCallback
(
PipelineCallback
):
"""
Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by
`cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.
Callback function for
the base
Stable Diffusion XL Pipelines. After certain number of steps (set by
`cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
tensor_inputs
=
[
"prompt_embeds"
,
"add_text_embeds"
,
"add_time_ids"
]
tensor_inputs
=
[
"prompt_embeds"
,
"add_text_embeds"
,
"add_time_ids"
,
]
def
callback_fn
(
self
,
pipeline
,
step_index
,
timestep
,
callback_kwargs
)
->
Dict
[
str
,
Any
]:
cutoff_step_ratio
=
self
.
config
.
cutoff_step_ratio
...
...
@@ -129,6 +133,55 @@ class SDXLCFGCutoffCallback(PipelineCallback):
callback_kwargs
[
self
.
tensor_inputs
[
0
]]
=
prompt_embeds
callback_kwargs
[
self
.
tensor_inputs
[
1
]]
=
add_text_embeds
callback_kwargs
[
self
.
tensor_inputs
[
2
]]
=
add_time_ids
return
callback_kwargs
class
SDXLControlnetCFGCutoffCallback
(
PipelineCallback
):
"""
Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by
`cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
tensor_inputs
=
[
"prompt_embeds"
,
"add_text_embeds"
,
"add_time_ids"
,
"image"
,
]
def
callback_fn
(
self
,
pipeline
,
step_index
,
timestep
,
callback_kwargs
)
->
Dict
[
str
,
Any
]:
cutoff_step_ratio
=
self
.
config
.
cutoff_step_ratio
cutoff_step_index
=
self
.
config
.
cutoff_step_index
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step
=
(
cutoff_step_index
if
cutoff_step_index
is
not
None
else
int
(
pipeline
.
num_timesteps
*
cutoff_step_ratio
)
)
if
step_index
==
cutoff_step
:
prompt_embeds
=
callback_kwargs
[
self
.
tensor_inputs
[
0
]]
prompt_embeds
=
prompt_embeds
[
-
1
:]
# "-1" denotes the embeddings for conditional text tokens.
add_text_embeds
=
callback_kwargs
[
self
.
tensor_inputs
[
1
]]
add_text_embeds
=
add_text_embeds
[
-
1
:]
# "-1" denotes the embeddings for conditional pooled text tokens
add_time_ids
=
callback_kwargs
[
self
.
tensor_inputs
[
2
]]
add_time_ids
=
add_time_ids
[
-
1
:]
# "-1" denotes the embeddings for conditional added time vector
# For Controlnet
image
=
callback_kwargs
[
self
.
tensor_inputs
[
3
]]
image
=
image
[
-
1
:]
pipeline
.
_guidance_scale
=
0.0
callback_kwargs
[
self
.
tensor_inputs
[
0
]]
=
prompt_embeds
callback_kwargs
[
self
.
tensor_inputs
[
1
]]
=
add_text_embeds
callback_kwargs
[
self
.
tensor_inputs
[
2
]]
=
add_time_ids
callback_kwargs
[
self
.
tensor_inputs
[
3
]]
=
image
return
callback_kwargs
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
View file @
ab1b7b20
...
...
@@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline(
"add_time_ids"
,
"negative_pooled_prompt_embeds"
,
"negative_add_time_ids"
,
"image"
,
]
def
__init__
(
...
...
@@ -1540,6 +1541,7 @@ class StableDiffusionXLControlNetPipeline(
)
add_time_ids
=
callback_outputs
.
pop
(
"add_time_ids"
,
add_time_ids
)
negative_add_time_ids
=
callback_outputs
.
pop
(
"negative_add_time_ids"
,
negative_add_time_ids
)
image
=
callback_outputs
.
pop
(
"image"
,
image
)
# call the callback, if provided
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment