Unverified Commit ab1b7b20 authored by Álvaro Somoza's avatar Álvaro Somoza Committed by GitHub
Browse files

[Official callbacks] SDXL Controlnet CFG Cutoff (#9311)

* initial proposal

* style
parent 9366c8f8
......@@ -97,13 +97,17 @@ class SDCFGCutoffCallback(PipelineCallback):
class SDXLCFGCutoffCallback(PipelineCallback):
"""
Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.
Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by
`cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
tensor_inputs = [
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
......@@ -129,6 +133,55 @@ class SDXLCFGCutoffCallback(PipelineCallback):
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
return callback_kwargs
class SDXLControlnetCFGCutoffCallback(PipelineCallback):
"""
Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by
`cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
tensor_inputs = [
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
"image",
]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)
if step_index == cutoff_step:
prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens
add_time_ids = callback_kwargs[self.tensor_inputs[2]]
add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector
# For Controlnet
image = callback_kwargs[self.tensor_inputs[3]]
image = image[-1:]
pipeline._guidance_scale = 0.0
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
callback_kwargs[self.tensor_inputs[3]] = image
return callback_kwargs
......
......@@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline(
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
"image",
]
def __init__(
......@@ -1540,6 +1541,7 @@ class StableDiffusionXLControlNetPipeline(
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
image = callback_outputs.pop("image", image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment