"vscode:/vscode.git/clone" did not exist on "9c7bf1bc586fea64e5729ea8b2cc4a68979f3ffe"
Unverified Commit 5c9dd0af authored by Takuma Mori's avatar Takuma Mori Committed by GitHub
Browse files

Add support for Guess Mode in StableDiffusionControlNetPipeline (#2998)

* add guess mode (WIP)

* fix uncond/cond order

* support guidance_scale=1.0 and batch != 1

* remove magic coeff

* add docstring

* add integration test

* add document to controlnet.mdx

* made the comments a bit more explanatory

* fix table

Parent: d0f25820
@@ -242,6 +242,42 @@ image.save("./multi_controlnet_output.png")

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/multi_controlnet_output.png" width=600/>

### Guess Mode
Guess Mode is [a ControlNet feature that was implemented](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode) after the publication of [the paper](https://arxiv.org/abs/2302.05543). The description states:
>In this mode, the ControlNet encoder will try best to recognize the content of the input control map, like depth map, edge map, scribbles, etc, even if you remove all prompts.
#### The core implementation:
It adjusts the scale of the output residuals from ControlNet by a fixed ratio that depends on block depth: the shallowest DownBlock output is scaled by `0.1`, the scale grows exponentially with depth, and the MidBlock output is scaled by `1.0`.

Since this scaling is the entire mechanism, **it has no impact on prompt conditioning**. Guess Mode is commonly used without any prompt, but providing one is still possible.
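Concretely, the per-block scales form a log-spaced ramp from `0.1` to `1.0`. A minimal sketch of the schedule (the count of 13 assumes Stable Diffusion 1.5's twelve down-block residuals plus one mid-block residual, matching `len(down_block_res_samples) + 1` in the implementation further down this page):

```py
import torch

# 12 down-block residuals + 1 mid-block residual -> 13 scales
scales = torch.logspace(-1, 0, 13)  # exponential ramp from 0.1 to 1.0
print([round(s.item(), 3) for s in scales])
# [0.1, 0.121, 0.147, 0.178, 0.215, 0.261, 0.316, 0.383, 0.464, 0.562, 0.681, 0.825, 1.0]
```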
#### Usage:
Just specify `guess_mode=True` when calling the pipeline. A `guidance_scale` between 3.0 and 5.0 is [recommended](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode).
```py
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image

# A pre-computed canny edge map to use as the control image.
canny_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
).to("cuda")

image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
image.save("guess_mode_generated.png")
```
#### Output image comparison:
Canny Control Example
|no guess_mode, with prompt|guess_mode, without prompt|
|---|---|
|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"/></a>|
## Available checkpoints

ControlNet requires a *control image* in addition to the text-to-image *prompt*.

...
```diff
@@ -456,6 +456,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         # check channel order
@@ -556,6 +557,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         mid_block_res_sample = self.controlnet_mid_block(sample)

         # 6. scaling
-        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-        mid_block_res_sample *= conditioning_scale
+        if guess_mode:
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
+            scales *= conditioning_scale
+            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+            mid_block_res_sample *= scales[-1]  # last one
+        else:
+            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+            mid_block_res_sample *= conditioning_scale
```

...
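To see the new flag end to end, the model can be called directly. A minimal sketch with dummy tensors (the shapes, the timestep value, and the variable names are illustrative, not taken from the commit):

```py
import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

sample = torch.randn(1, 4, 64, 64)   # dummy latent batch
cond = torch.randn(1, 3, 512, 512)   # dummy control image
embeds = torch.randn(1, 77, 768)     # dummy prompt embeddings

with torch.no_grad():
    down_res, mid_res = controlnet(
        sample,
        0,  # timestep
        encoder_hidden_states=embeds,
        controlnet_cond=cond,
        guess_mode=True,
        return_dict=False,
    )

# With guess_mode=True the shallowest residual is scaled by ~0.1,
# the mid-block residual by 1.0.
print(len(down_res), mid_res.shape)
```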
```diff
@@ -118,6 +118,7 @@ class MultiControlNetModel(ModelMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@ class MultiControlNetModel(ModelMixin):
                 timestep_cond,
                 attention_mask,
                 cross_attention_kwargs,
+                guess_mode,
                 return_dict,
             )
@@ -627,7 +629,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         )

     def prepare_image(
-        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance,
+        guess_mode,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -664,7 +675,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         image = image.to(device=device, dtype=dtype)

-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and not guess_mode:
             image = torch.cat([image] * 2)

         return image
```
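The effect of the `prepare_image` change on the control-image batch can be checked in isolation. A standalone sketch with a dummy tensor (the shape is illustrative):

```py
import torch

control_image = torch.randn(1, 3, 512, 512)  # dummy prepared control image
do_classifier_free_guidance, guess_mode = True, True

# Under classifier-free guidance the control image is normally duplicated to
# match the [unconditional, conditional] latent batch. In guess mode the
# ControlNet is run only on the conditional half, so no duplication is needed.
if do_classifier_free_guidance and not guess_mode:
    control_image = torch.cat([control_image] * 2)

print(control_image.shape)  # guess mode: torch.Size([1, 3, 512, 512])
```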
```diff
@@ -747,6 +758,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        guess_mode: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -819,6 +831,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                 to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                 corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try its best to recognize the content of the input image
+                even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.

         Examples:

         Returns:
@@ -883,6 +899,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 device=device,
                 dtype=self.controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
+                guess_mode=guess_mode,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -897,6 +914,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     device=device,
                     dtype=self.controlnet.dtype,
                     do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
                 )

                 images.append(image_)
@@ -934,15 +952,31 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    controlnet_latent_model_input = latents
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                else:
+                    controlnet_latent_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
+                    controlnet_latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,
                     conditioning_scale=controlnet_conditioning_scale,
+                    guess_mode=guess_mode,
                     return_dict=False,
                 )

+                if guess_mode and do_classifier_free_guidance:
+                    # ControlNet was inferred only for the conditional batch.
+                    # To apply its output to both the unconditional and conditional batches,
+                    # prepend zeros so the unconditional batch stays unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
```

...
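The zero-padding in the last hunk works because the UNet adds these residuals: an all-zero residual leaves the unconditional prediction untouched. A minimal sketch with a dummy residual (the shape is illustrative):

```py
import torch

# Dummy mid-block residual, computed for the conditional batch only.
mid_block_res_sample = torch.randn(1, 1280, 8, 8)

# Prepend zeros for the unconditional half of the batch; adding zero to the
# unconditional UNet activations leaves that prediction unchanged.
mid_block_res_sample = torch.cat(
    [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
)

print(mid_block_res_sample.shape)  # torch.Size([2, 1280, 8, 8]) -> [uncond, cond]
```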
```diff
@@ -553,6 +553,38 @@ class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase):
         # make sure that less than 7 GB is allocated
         assert mem_bytes < 4 * 10**9

+    def test_canny_guess_mode(self):
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+        )
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = ""
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+
+        output = pipe(
+            prompt,
+            image,
+            generator=generator,
+            output_type="np",
+            num_inference_steps=3,
+            guidance_scale=3.0,
+            guess_mode=True,
+        )
+
+        image = output.images[0]
+        assert image.shape == (768, 512, 3)
+
+        image_slice = image[-3:, -3:, -1]
+        expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

 @slow
 @require_torch_gpu
```

...