Unverified Commit 9a45d7fb authored by Joachim Blaafjell Holwech, committed by GitHub

Add guidance start/stop (#3770)



* Add guidance start/stop

* Add guidance start/stop to inpaint class

* Black formatting

* Add support for guidance for multicontrolnet

* Add inclusive end

* Improve design

* correct imports

* Finish

* Finish all

* Correct more

* make style

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 61916fef
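
For orientation, here is a minimal sketch of how the new arguments are meant to be used (editor's illustration, not part of this commit; the checkpoint IDs and control-image URL are assumptions, only the `control_guidance_start`/`control_guidance_end` keywords come from this diff):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Assumed checkpoints for illustration; any SD 1.x base plus a matching controlnet works.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

canny = load_image("https://example.com/canny_edges.png")  # hypothetical control image URL

image = pipe(
    "a futuristic city at night",
    image=canny,
    num_inference_steps=50,
    control_guidance_start=0.0,  # conditioning active from the very first step ...
    control_guidance_end=0.5,    # ... and switched off after 50% of the steps
).images[0]
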
@@ -491,6 +491,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -593,6 +595,27 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
else:
assert False
if len(control_guidance_start) != len(control_guidance_end):
raise ValueError(
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
)
if isinstance(self.controlnet, MultiControlNetModel):
if len(control_guidance_start) != len(self.controlnet.nets):
raise ValueError(
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
)
for start, end in zip(control_guidance_start, control_guidance_end):
if start >= end:
raise ValueError(
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
)
if start < 0.0:
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
@@ -709,6 +732,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -784,6 +809,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The fraction of total steps at which the controlnet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The fraction of total steps at which the controlnet stops applying.
Examples:
@@ -794,6 +823,18 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
control_guidance_end
]
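
The block above broadcasts scalar arguments so that the rest of the pipeline can assume two lists of equal length, one entry per controlnet. The same rule as a self-contained function (editor's sketch; the function name is not part of the diff):

def align_control_guidance(start, end, num_controlnets):
    # Mirrors the "align format for control guidance" block above.
    if not isinstance(start, list) and isinstance(end, list):
        start = len(end) * [start]  # scalar start, list end
    elif not isinstance(end, list) and isinstance(start, list):
        end = len(start) * [end]  # list start, scalar end
    elif not isinstance(start, list) and not isinstance(end, list):
        start, end = num_controlnets * [start], num_controlnets * [end]  # both scalar
    return start, end

print(align_control_guidance(0.4, [0.5, 0.8], 2))  # ([0.4, 0.4], [0.5, 0.8])
print(align_control_guidance(0.0, 1.0, 2))         # ([0.0, 0.0], [1.0, 1.0])
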
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -804,6 +845,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
)
# 2. Define call parameters
@@ -820,8 +863,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -904,6 +945,15 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.1 Create a list indicating, per step, which controlnets to keep
controlnet_keep = []
for i in range(num_inference_steps):
keeps = [
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
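
Each entry of `controlnet_keep` is 1.0 when the step falls inside the controlnet's [start, end] window and 0.0 otherwise; in the denoising loop below it is multiplied into the conditioning scale, so a 0.0 disables that controlnet for the step. A worked example with illustrative values (editor's sketch, not part of the diff):

num_inference_steps, start, end = 10, 0.1, 0.5
keep = [
    1.0 - float(i / num_inference_steps < start or (i + 1) / num_inference_steps > end)
    for i in range(num_inference_steps)
]
print(keep)  # [0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# Per-step scale handed to the controlnet call; 0.0 switches it off entirely.
controlnet_conditioning_scale = 1.0
cond_scales = [controlnet_conditioning_scale * k for k in keep]
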
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -922,12 +972,17 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
conditioning_scale=controlnet_conditioning_scale,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
@@ -517,6 +517,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -619,6 +621,27 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
else:
assert False
if len(control_guidance_start) != len(control_guidance_end):
raise ValueError(
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
)
if isinstance(self.controlnet, MultiControlNetModel):
if len(control_guidance_start) != len(self.controlnet.nets):
raise ValueError(
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
)
for start, end in zip(control_guidance_start, control_guidance_end):
if start >= end:
raise ValueError(
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
)
if start < 0.0:
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
@@ -796,6 +819,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -876,6 +901,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The fraction of total steps at which the controlnet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The fraction of total steps at which the controlnet stops applying.
Examples:
@@ -886,6 +915,19 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
control_guidance_end
]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
@@ -895,6 +937,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
)
# 2. Define call parameters
@@ -994,6 +1038,15 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.1 Create a list indicating, per step, which controlnets to keep
controlnet_keep = []
for i in range(num_inference_steps):
keeps = [
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1012,12 +1065,17 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image,
conditioning_scale=controlnet_conditioning_scale,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
@@ -646,6 +646,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -751,6 +753,27 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
else:
assert False
if len(control_guidance_start) != len(control_guidance_end):
raise ValueError(
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
)
if isinstance(self.controlnet, MultiControlNetModel):
if len(control_guidance_start) != len(self.controlnet.nets):
raise ValueError(
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
)
for start, end in zip(control_guidance_start, control_guidance_end):
if start >= end:
raise ValueError(
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
)
if start < 0.0:
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
@@ -990,6 +1013,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1073,6 +1098,10 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The fraction of total steps at which the controlnet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The fraction of total steps at which the controlnet stops applying.
Examples:
@@ -1083,9 +1112,22 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# 0. Default height and width to unet
height, width = self._default_height_width(height, width, image)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
control_guidance_end
]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
@@ -1097,6 +1139,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
)
# 2. Define call parameters
@@ -1113,8 +1157,6 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1231,6 +1273,15 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.1 Create a list indicating, per step, which controlnets to keep
controlnet_keep = []
for i in range(num_inference_steps):
keeps = [
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1249,12 +1300,17 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image,
conditioning_scale=controlnet_conditioning_scale,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
@@ -226,6 +226,12 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
cross_attention_dim=32,
)
torch.manual_seed(0)
def init_weights(m):
if isinstance(m, torch.nn.Conv2d):
# ControlNet output convs are zero-initialized, so re-initialize them here to give
# the dummy controlnets a real effect on the UNet; use the in-place initializer
# (torch.nn.init.normal is deprecated in favor of normal_).
torch.nn.init.normal_(m.weight)
m.bias.data.fill_(1.0)
controlnet1 = ControlNetModel(
block_out_channels=(32, 64),
layers_per_block=2,
@@ -234,6 +240,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
)
controlnet1.controlnet_down_blocks.apply(init_weights)
torch.manual_seed(0)
controlnet2 = ControlNetModel(
block_out_channels=(32, 64),
@@ -243,6 +251,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
)
controlnet2.controlnet_down_blocks.apply(init_weights)
torch.manual_seed(0)
scheduler = DDIMScheduler(
beta_start=0.00085,
@@ -321,6 +331,39 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
return inputs
def test_control_guidance_switch(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
scale = 10.0
steps = 4
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_1 = pipe(**inputs)[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_3 = pipe(**inputs, control_guidance_start=[0.1, 0.3], control_guidance_end=[0.2, 0.7])[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5, 0.8])[0]
# make sure that all outputs are different
assert np.sum(np.abs(output_1 - output_2)) > 1e-3
assert np.sum(np.abs(output_1 - output_3)) > 1e-3
assert np.sum(np.abs(output_1 - output_4)) > 1e-3
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
@@ -180,6 +180,12 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
cross_attention_dim=32,
)
torch.manual_seed(0)
def init_weights(m):
if isinstance(m, torch.nn.Conv2d):
# ControlNet output convs are zero-initialized, so re-initialize them here to give
# the dummy controlnets a real effect on the UNet; use the in-place initializer
# (torch.nn.init.normal is deprecated in favor of normal_).
torch.nn.init.normal_(m.weight)
m.bias.data.fill_(1.0)
controlnet1 = ControlNetModel(
block_out_channels=(32, 64),
layers_per_block=2,
@@ -188,6 +194,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
)
controlnet1.controlnet_down_blocks.apply(init_weights)
torch.manual_seed(0)
controlnet2 = ControlNetModel(
block_out_channels=(32, 64),
@@ -197,6 +205,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
)
controlnet2.controlnet_down_blocks.apply(init_weights)
torch.manual_seed(0)
scheduler = DDIMScheduler(
beta_start=0.00085,
@@ -279,6 +289,39 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
return inputs
def test_control_guidance_switch(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
scale = 10.0
steps = 4
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_1 = pipe(**inputs)[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_3 = pipe(**inputs, control_guidance_start=[0.1, 0.3], control_guidance_end=[0.2, 0.7])[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5, 0.8])[0]
# make sure that all outputs are different
assert np.sum(np.abs(output_1 - output_2)) > 1e-3
assert np.sum(np.abs(output_1 - output_3)) > 1e-3
assert np.sum(np.abs(output_1 - output_4)) > 1e-3
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
@@ -255,6 +255,12 @@ class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
cross_attention_dim=32,
)
torch.manual_seed(0)
def init_weights(m):
if isinstance(m, torch.nn.Conv2d):
# ControlNet output convs are zero-initialized, so re-initialize them here to give
# the dummy controlnets a real effect on the UNet; use the in-place initializer
# (torch.nn.init.normal is deprecated in favor of normal_).
torch.nn.init.normal_(m.weight)
m.bias.data.fill_(1.0)
controlnet1 = ControlNetModel(
block_out_channels=(32, 64),
layers_per_block=2,
@@ -263,6 +269,8 @@ class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
)
controlnet1.controlnet_down_blocks.apply(init_weights)
torch.manual_seed(0)
controlnet2 = ControlNetModel(
block_out_channels=(32, 64),
@@ -272,6 +280,8 @@ class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
)
controlnet2.controlnet_down_blocks.apply(init_weights)
torch.manual_seed(0)
scheduler = DDIMScheduler(
beta_start=0.00085,
@@ -357,6 +367,39 @@ class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
return inputs
def test_control_guidance_switch(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
scale = 10.0
steps = 4
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_1 = pipe(**inputs)[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_3 = pipe(**inputs, control_guidance_start=[0.1, 0.3], control_guidance_end=[0.2, 0.7])[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = steps
inputs["controlnet_conditioning_scale"] = scale
output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5, 0.8])[0]
# make sure that all outputs are different
assert np.sum(np.abs(output_1 - output_2)) > 1e-3
assert np.sum(np.abs(output_1 - output_3)) > 1e-3
assert np.sum(np.abs(output_1 - output_4)) > 1e-3
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)