Unverified commit 58bcf46a, authored by Hyowon Ha, committed by GitHub

Add guidance start/end parameters to StableDiffusionControlNetImg2ImgPipeline (#2731)

* Add guidance start/end parameters to community controlnet img2img pipeline

* Fix formats
parent 0042efd0
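
For context, a minimal usage sketch (not part of the commit) showing how the new parameters can be passed when calling this community pipeline. The model IDs, image URLs, and values below are placeholders, the `custom_pipeline` name is assumed from the community file name, and the argument names mirror the diff:

    import torch
    from diffusers import ControlNetModel, DiffusionPipeline
    from diffusers.utils import load_image

    # Load the community img2img controlnet pipeline modified by this commit.
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        custom_pipeline="stable_diffusion_controlnet_img2img",
        torch_dtype=torch.float16,
    ).to("cuda")

    init_image = load_image("https://example.com/init.png")            # placeholder URL
    control_image = load_image("https://example.com/canny_edges.png")  # placeholder URL

    image = pipe(
        prompt="a futuristic cityscape at dusk",
        image=init_image,
        controlnet_conditioning_image=control_image,
        strength=0.8,
        num_inference_steps=50,
        controlnet_conditioning_scale=1.0,
        # new in this commit: apply the controlnet only between 20% and 80% of the sampling steps
        controlnet_guidance_start=0.2,
        controlnet_guidance_end=0.8,
    ).images[0]
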
@@ -437,6 +437,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
         prompt_embeds=None,
         negative_prompt_embeds=None,
         strength=None,
+        controlnet_guidance_start=None,
+        controlnet_guidance_end=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -542,7 +544,23 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
             )
         if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+            raise ValueError(f"The value of `strength` should in [0.0, 1.0] but is {strength}")
+        if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:
+            raise ValueError(
+                f"The value of `controlnet_guidance_start` should in [0.0, 1.0] but is {controlnet_guidance_start}"
+            )
+        if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:
+            raise ValueError(
+                f"The value of `controlnet_guidance_end` should in [0.0, 1.0] but is {controlnet_guidance_end}"
+            )
+        if controlnet_guidance_start > controlnet_guidance_end:
+            raise ValueError(
+                "The value of `controlnet_guidance_start` should be less than `controlnet_guidance_end`, but got"
+                f" `controlnet_guidance_start` {controlnet_guidance_start} >= `controlnet_guidance_end` {controlnet_guidance_end}"
+            )

     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
@@ -643,6 +661,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: float = 1.0,
+        controlnet_guidance_start: float = 0.0,
+        controlnet_guidance_end: float = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
...@@ -719,6 +739,11 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): ...@@ -719,6 +739,11 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet. to the residual in the original unet.
controlnet_guidance_start ('float', *optional*, defaults to 0.0):
The percentage of total steps the controlnet starts applying. Must be between 0 and 1.
controlnet_guidance_end ('float', *optional*, defaults to 1.0):
The percentage of total steps the controlnet ends applying. Must be between 0 and 1. Must be greater
than `controlnet_guidance_start`.
Examples: Examples:
@@ -745,6 +770,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
             prompt_embeds,
             negative_prompt_embeds,
             strength,
+            controlnet_guidance_start,
+            controlnet_guidance_end,
         )

         # 2. Define call parameters
@@ -820,19 +847,31 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    controlnet_cond=controlnet_conditioning_image,
-                    return_dict=False,
-                )
-
-                down_block_res_samples = [
-                    down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
-                ]
-                mid_block_res_sample *= controlnet_conditioning_scale
+                # compute the percentage of total steps we are at
+                current_sampling_percent = i / len(timesteps)
+
+                if (
+                    current_sampling_percent < controlnet_guidance_start
+                    or current_sampling_percent > controlnet_guidance_end
+                ):
+                    # do not apply the controlnet
+                    down_block_res_samples = None
+                    mid_block_res_sample = None
+                else:
+                    # apply the controlnet
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_cond=controlnet_conditioning_image,
+                        return_dict=False,
+                    )
+
+                    down_block_res_samples = [
+                        down_block_res_sample * controlnet_conditioning_scale
+                        for down_block_res_sample in down_block_res_samples
+                    ]
+                    mid_block_res_sample *= controlnet_conditioning_scale

                 # predict the noise residual
                 noise_pred = self.unet(
...
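
The gating rule added in the loop above reduces to: apply the controlnet only while i / len(timesteps) lies inside the [controlnet_guidance_start, controlnet_guidance_end] window. A standalone sketch of that check, with illustrative step counts and names mirroring the diff:

    def controlnet_is_applied(i, num_timesteps, controlnet_guidance_start, controlnet_guidance_end):
        # Mirrors the check in the diff: the controlnet is skipped when the current
        # sampling percentage falls outside the [start, end] window.
        current_sampling_percent = i / num_timesteps
        return controlnet_guidance_start <= current_sampling_percent <= controlnet_guidance_end

    # With 50 steps and a 0.2-0.8 window, the controlnet runs on steps 10 through 40.
    applied = [i for i in range(50) if controlnet_is_applied(i, 50, 0.2, 0.8)]
    assert applied[0] == 10 and applied[-1] == 40

For steps outside the window, the diff sets the controlnet residuals to None, so those denoising steps effectively run without controlnet conditioning.
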