Unverified Commit 9a45d7fb authored by Joachim Blaafjell Holwech, committed by GitHub

Add guidance start/stop (#3770)



* Add guidance start/stop

* Add guidance start/stop to inpaint class

* Black formatting

* Add support for guidance for multicontrolnet

* Add inclusive end

* Improve design

* correct imports

* Finish

* Finish all

* Correct more

* make style

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 61916fef
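
The new `control_guidance_start` / `control_guidance_end` arguments restrict ControlNet conditioning to a fraction of the denoising schedule. A minimal usage sketch (the model IDs and the preprocessed `canny_image` input are illustrative, not part of this commit):

    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
    ).to("cuda")

    # apply the controlnet only during the first 80% of the denoising steps
    image = pipe(
        "a futuristic city at night",
        image=canny_image,  # a Canny edge map (PIL.Image) prepared beforehand
        num_inference_steps=50,
        control_guidance_start=0.0,
        control_guidance_end=0.8,
    ).images[0]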
@@ -491,6 +491,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
        prompt_embeds=None,
        negative_prompt_embeds=None,
        controlnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -593,6 +595,27 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
        else:
            assert False

        if len(control_guidance_start) != len(control_guidance_end):
            raise ValueError(
                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
            )

        if isinstance(self.controlnet, MultiControlNetModel):
            if len(control_guidance_start) != len(self.controlnet.nets):
                raise ValueError(
                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
                )

        for start, end in zip(control_guidance_start, control_guidance_end):
            if start >= end:
                raise ValueError(
                    f"control guidance start: {start} cannot be larger than or equal to control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
            if end > 1.0:
                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    def check_image(self, image, prompt, prompt_embeds):
        image_is_pil = isinstance(image, PIL.Image.Image)
        image_is_tensor = isinstance(image, torch.Tensor)
@@ -709,6 +732,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -784,6 +809,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
            guess_mode (`bool`, *optional*, defaults to `False`):
                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                The percentage of total steps at which the controlnet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the controlnet stops applying.

        Examples:

@@ -794,6 +823,18 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
                control_guidance_end
            ]
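        # For illustration (comment only, not part of the change): with two controlnets,
        # control_guidance_start=0.4 and control_guidance_end=[0.5, 0.8] are broadcast to
        # [0.4, 0.4] and [0.5, 0.8], i.e. both lists end up with one entry per controlnet.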

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
@@ -804,6 +845,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
            prompt_embeds,
            negative_prompt_embeds,
            controlnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
        )

        # 2. Define call parameters
@@ -820,8 +863,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -904,6 +945,15 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Create a list telling which controlnets to keep at each denoising step
        controlnet_keep = []
        for i in range(num_inference_steps):
            keeps = [
                1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
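        # Worked example (comment only): with num_inference_steps=50, start=0.1 and end=0.8,
        # a step i is kept (1.0) only while 0.1 <= i/50 and (i+1)/50 <= 0.8, i.e. steps 5..39;
        # outside that window the controlnet's conditioning scale is zeroed below.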

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -922,12 +972,17 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
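                # controlnet_keep[i] is a per-net list for MultiControlNet and a plain float
                # otherwise, so cond_scale mirrors the shape of controlnet_conditioning_scale.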

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )
...
@@ -517,6 +517,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
        prompt_embeds=None,
        negative_prompt_embeds=None,
        controlnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -619,6 +621,27 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
        else:
            assert False

        if len(control_guidance_start) != len(control_guidance_end):
            raise ValueError(
                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
            )

        if isinstance(self.controlnet, MultiControlNetModel):
            if len(control_guidance_start) != len(self.controlnet.nets):
                raise ValueError(
                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
                )

        for start, end in zip(control_guidance_start, control_guidance_end):
            if start >= end:
                raise ValueError(
                    f"control guidance start: {start} cannot be larger than or equal to control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
            if end > 1.0:
                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
    def check_image(self, image, prompt, prompt_embeds):
        image_is_pil = isinstance(image, PIL.Image.Image)
@@ -796,6 +819,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -876,6 +901,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
            guess_mode (`bool`, *optional*, defaults to `False`):
                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                The percentage of total steps at which the controlnet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the controlnet stops applying.

        Examples:

@@ -886,6 +915,19 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
                control_guidance_end
            ]

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
@@ -895,6 +937,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
            prompt_embeds,
            negative_prompt_embeds,
            controlnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
        )

        # 2. Define call parameters
@@ -994,6 +1038,15 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Create a list telling which controlnets to keep at each denoising step
        controlnet_keep = []
        for i in range(num_inference_steps):
            keeps = [
                1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1012,12 +1065,17 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    cond_scale = controlnet_conditioning_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=control_image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )
...
@@ -646,6 +646,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
        prompt_embeds=None,
        negative_prompt_embeds=None,
        controlnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -751,6 +753,27 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
        else:
            assert False

        if len(control_guidance_start) != len(control_guidance_end):
            raise ValueError(
                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
            )

        if isinstance(self.controlnet, MultiControlNetModel):
            if len(control_guidance_start) != len(self.controlnet.nets):
                raise ValueError(
                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
                )

        for start, end in zip(control_guidance_start, control_guidance_end):
            if start >= end:
                raise ValueError(
                    f"control guidance start: {start} cannot be larger than or equal to control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
            if end > 1.0:
                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
    def check_image(self, image, prompt, prompt_embeds):
        image_is_pil = isinstance(image, PIL.Image.Image)
@@ -990,6 +1013,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -1073,6 +1098,10 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
            guess_mode (`bool`, *optional*, defaults to `False`):
                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                The percentage of total steps at which the controlnet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the controlnet stops applying.

        Examples:

@@ -1083,9 +1112,22 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, image)

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
                control_guidance_end
            ]

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
@@ -1097,6 +1139,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
            prompt_embeds,
            negative_prompt_embeds,
            controlnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
        )

        # 2. Define call parameters
@@ -1113,8 +1157,6 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1231,6 +1273,15 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Create a list telling which controlnets to keep at each denoising step
        controlnet_keep = []
        for i in range(num_inference_steps):
            keeps = [
                1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1249,12 +1300,17 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    cond_scale = controlnet_conditioning_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=control_image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )
...
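
Taken together, the per-step gating added in all three pipelines reduces to one small piece of logic: zero a ControlNet's conditioning scale whenever the current step falls outside its `[start, end]` window, expressed as fractions of the total steps. A self-contained sketch of that logic (the helper name `controlnet_keep_schedule` is hypothetical, introduced only for illustration):

    from typing import List

    def controlnet_keep_schedule(num_inference_steps: int, starts: List[float], ends: List[float]) -> List[List[float]]:
        # 1.0 while step i lies inside a net's [start, end] window, 0.0 otherwise
        return [
            [
                1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
                for s, e in zip(starts, ends)
            ]
            for i in range(num_inference_steps)
        ]

    # two controlnets over 10 steps: the first is active for steps 0-4, the second for steps 3-7
    keep = controlnet_keep_schedule(10, starts=[0.0, 0.3], ends=[0.5, 0.8])

Multiplying these keep factors into `controlnet_conditioning_scale` at every step is exactly how the pipelines switch each ControlNet on and off.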
@@ -226,6 +226,12 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
            cross_attention_dim=32,
        )
        torch.manual_seed(0)

        def init_weights(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.normal_(m.weight)
                m.bias.data.fill_(1.0)
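        # (presumably needed because the controlnet output convs are zero-initialized, which
        # would make the controlnets near no-ops here and hide the effect of the guidance window)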

        controlnet1 = ControlNetModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
@@ -234,6 +240,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
            cross_attention_dim=32,
            conditioning_embedding_out_channels=(16, 32),
        )
        controlnet1.controlnet_down_blocks.apply(init_weights)

        torch.manual_seed(0)
        controlnet2 = ControlNetModel(
            block_out_channels=(32, 64),
@@ -243,6 +251,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
            cross_attention_dim=32,
            conditioning_embedding_out_channels=(16, 32),
        )
        controlnet2.controlnet_down_blocks.apply(init_weights)

        torch.manual_seed(0)
        scheduler = DDIMScheduler(
            beta_start=0.00085,
@@ -321,6 +331,39 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
        return inputs

    def test_control_guidance_switch(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(torch_device)

        scale = 10.0
        steps = 4

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_1 = pipe(**inputs)[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_3 = pipe(**inputs, control_guidance_start=[0.1, 0.3], control_guidance_end=[0.2, 0.7])[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5, 0.8])[0]

        # make sure that all outputs are different
        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
        assert np.sum(np.abs(output_1 - output_3)) > 1e-3
        assert np.sum(np.abs(output_1 - output_4)) > 1e-3

    def test_attention_slicing_forward_pass(self):
        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
...
@@ -180,6 +180,12 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
            cross_attention_dim=32,
        )
        torch.manual_seed(0)

        def init_weights(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.normal_(m.weight)
                m.bias.data.fill_(1.0)

        controlnet1 = ControlNetModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
@@ -188,6 +194,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
            cross_attention_dim=32,
            conditioning_embedding_out_channels=(16, 32),
        )
        controlnet1.controlnet_down_blocks.apply(init_weights)

        torch.manual_seed(0)
        controlnet2 = ControlNetModel(
            block_out_channels=(32, 64),
@@ -197,6 +205,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
            cross_attention_dim=32,
            conditioning_embedding_out_channels=(16, 32),
        )
        controlnet2.controlnet_down_blocks.apply(init_weights)

        torch.manual_seed(0)
        scheduler = DDIMScheduler(
            beta_start=0.00085,
@@ -279,6 +289,39 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
        return inputs

    def test_control_guidance_switch(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(torch_device)

        scale = 10.0
        steps = 4

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_1 = pipe(**inputs)[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_3 = pipe(**inputs, control_guidance_start=[0.1, 0.3], control_guidance_end=[0.2, 0.7])[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5, 0.8])[0]

        # make sure that all outputs are different
        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
        assert np.sum(np.abs(output_1 - output_3)) > 1e-3
        assert np.sum(np.abs(output_1 - output_4)) > 1e-3

    def test_attention_slicing_forward_pass(self):
        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
...
@@ -255,6 +255,12 @@ class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
            cross_attention_dim=32,
        )
        torch.manual_seed(0)

        def init_weights(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.normal_(m.weight)
                m.bias.data.fill_(1.0)

        controlnet1 = ControlNetModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
@@ -263,6 +269,8 @@ class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
            cross_attention_dim=32,
            conditioning_embedding_out_channels=(16, 32),
        )
        controlnet1.controlnet_down_blocks.apply(init_weights)

        torch.manual_seed(0)
        controlnet2 = ControlNetModel(
            block_out_channels=(32, 64),
@@ -272,6 +280,8 @@ class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
            cross_attention_dim=32,
            conditioning_embedding_out_channels=(16, 32),
        )
        controlnet2.controlnet_down_blocks.apply(init_weights)

        torch.manual_seed(0)
        scheduler = DDIMScheduler(
            beta_start=0.00085,
@@ -357,6 +367,39 @@ class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
        return inputs

    def test_control_guidance_switch(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(torch_device)

        scale = 10.0
        steps = 4

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_1 = pipe(**inputs)[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_3 = pipe(**inputs, control_guidance_start=[0.1, 0.3], control_guidance_end=[0.2, 0.7])[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["num_inference_steps"] = steps
        inputs["controlnet_conditioning_scale"] = scale
        output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5, 0.8])[0]

        # make sure that all outputs are different
        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
        assert np.sum(np.abs(output_1 - output_3)) > 1e-3
        assert np.sum(np.abs(output_1 - output_4)) > 1e-3

    def test_attention_slicing_forward_pass(self):
        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
...