Unverified Commit bf16a970 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Fix controlnet guess mode euler (#3571)



* Fix guess mode controlnet for euler-like schedulers

* make style

* Co-authored-by: Chanchana Sornsoontorn <off.chanchana@gmail.com>

* Add co author Co-authored-by: Chanchana Sornsoontorn <off.chanchana@gmail.com>

* 2nd try
Co-authored-by: default avatarChanchana Sornsoontorn <off.chanchana@gmail.com>
parent 66356e7d
......@@ -956,14 +956,15 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
controlnet_latent_model_input = latents
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
controlnet_latent_model_input = latent_model_input
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
down_block_res_samples, mid_block_res_sample = self.controlnet(
controlnet_latent_model_input,
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
......
......@@ -1034,14 +1034,15 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
controlnet_latent_model_input = latents
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
controlnet_latent_model_input = latent_model_input
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
down_block_res_samples, mid_block_res_sample = self.controlnet(
controlnet_latent_model_input,
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image,
......
......@@ -1248,16 +1248,18 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
controlnet_latent_model_input = latents
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
controlnet_latent_model_input = latent_model_input
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
down_block_res_samples, mid_block_res_sample = self.controlnet(
controlnet_latent_model_input,
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image,
......
......@@ -26,6 +26,7 @@ from diffusers import (
AutoencoderKL,
ControlNetModel,
DDIMScheduler,
EulerDiscreteScheduler,
StableDiffusionControlNetPipeline,
UNet2DConditionModel,
)
......@@ -644,6 +645,39 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_canny_guess_mode_euler(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = ""
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
output = pipe(
prompt,
image,
generator=generator,
output_type="np",
num_inference_steps=3,
guidance_scale=3.0,
guess_mode=True,
)
image = output.images[0]
assert image.shape == (768, 512, 3)
image_slice = image[-3:, -3:, -1]
expected_slice = np.array([0.1655, 0.1721, 0.1623, 0.1685, 0.1711, 0.1646, 0.1651, 0.1631, 0.1494])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@require_torch_2
def test_stable_diffusion_compile(self):
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment