Unverified Commit 81780882 authored by Aryan V S, committed by GitHub

Addition of new callbacks to controlnets (#5812)



* add new callbacks to src/diffusers/pipelines/controlnet/pipeline_controlnet.py

* update callbacks

* fix repeated kwarg

* update

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent ebc7bede
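
For context, a minimal sketch of how the per-step callback API introduced in this commit can be used with the ControlNet pipeline. The checkpoint names, the placeholder image URL, and the helper names are illustrative assumptions, not part of this diff; only `callback_on_step_end`, `callback_on_step_end_tensor_inputs`, and `num_timesteps` come from the change itself:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# illustrative checkpoints; any SD 1.5 base + matching ControlNet pair works
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # callback_kwargs holds the tensors requested via callback_on_step_end_tensor_inputs;
    # the dict returned here is written back into the denoising loop
    print(f"step {step + 1}/{pipeline.num_timesteps}, t={timestep}")
    return callback_kwargs

canny_image = load_image("https://example.com/canny.png")  # placeholder input image
image = pipe(
    "a bird on a branch",
    image=canny_image,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]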
@@ -130,6 +130,7 @@ class StableDiffusionControlNetPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
         self,
@@ -485,15 +486,21 @@ class StableDiffusionControlNetPipeline(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -760,6 +767,10 @@ class StableDiffusionControlNetPipeline(
     def guidance_scale(self):
         return self._guidance_scale

+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
@@ -767,6 +778,14 @@ class StableDiffusionControlNetPipeline(
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -786,14 +805,15 @@ class StableDiffusionControlNetPipeline(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
...@@ -868,6 +888,15 @@ class StableDiffusionControlNetPipeline( ...@@ -868,6 +888,15 @@ class StableDiffusionControlNetPipeline(
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples: Examples:
@@ -878,6 +907,23 @@ class StableDiffusionControlNetPipeline(
             second element is a list of `bool`s indicating whether the corresponding generated image contains
             "not-safe-for-work" (nsfw) content.
         """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

         # align format for control guidance
@@ -903,9 +949,12 @@ class StableDiffusionControlNetPipeline(
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
         )

         self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -929,7 +978,7 @@ class StableDiffusionControlNetPipeline(
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
@@ -940,7 +989,7 @@ class StableDiffusionControlNetPipeline(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
@@ -988,6 +1037,7 @@ class StableDiffusionControlNetPipeline(
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+        self._num_timesteps = len(timesteps)

         # 6. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -1078,7 +1128,7 @@ class StableDiffusionControlNetPipeline(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     return_dict=False,
@@ -1087,11 +1137,21 @@ class StableDiffusionControlNetPipeline(
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
...
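Because the loop above pops `latents`, `prompt_embeds`, and `negative_prompt_embeds` back out of the returned dict, a callback can also capture or modify tensors mid-run. A sketch, reusing the hypothetical `pipe` and `canny_image` from the example above:

intermediate_latents = []

def capture_latents(pipeline, step, timestep, callback_kwargs):
    # clone so later scheduler steps don't mutate the stored tensor in place
    intermediate_latents.append(callback_kwargs["latents"].detach().clone())
    return callback_kwargs  # returned tensors are fed back into the loop unchanged

image = pipe(
    "a bird on a branch",
    image=canny_image,
    callback_on_step_end=capture_latents,
).images[0]
# intermediate_latents now holds one (batch, 4, height/8, width/8) latent per denoising step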
@@ -164,6 +164,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
         self,
@@ -519,15 +520,21 @@ class StableDiffusionControlNetImg2ImgPipeline(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -808,6 +815,29 @@ class StableDiffusionControlNetImg2ImgPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -829,14 +859,15 @@ class StableDiffusionControlNetImg2ImgPipeline(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -892,12 +923,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -915,6 +940,15 @@ class StableDiffusionControlNetImg2ImgPipeline(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.

         Examples:
@@ -925,6 +959,23 @@ class StableDiffusionControlNetImg2ImgPipeline(
             second element is a list of `bool`s indicating whether the corresponding generated image contains
             "not-safe-for-work" (nsfw) content.
         """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

         # align format for control guidance
@@ -950,8 +1001,13 @@ class StableDiffusionControlNetImg2ImgPipeline(
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -961,10 +1017,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -978,23 +1030,23 @@ class StableDiffusionControlNetImg2ImgPipeline(
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         # 4. Prepare image
@@ -1010,7 +1062,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
         elif isinstance(controlnet, MultiControlNetModel):
@@ -1025,7 +1077,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
@@ -1039,6 +1091,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        self._num_timesteps = len(timesteps)

         # 6. Prepare latent variables
         latents = self.prepare_latents(
@@ -1068,11 +1121,11 @@ class StableDiffusionControlNetImg2ImgPipeline(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1099,7 +1152,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
                     return_dict=False,
                 )

-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Inferred ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1111,20 +1164,30 @@ class StableDiffusionControlNetImg2ImgPipeline(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     return_dict=False,
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
...
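Note how the deprecation shim in these `__call__` bodies works: `callback` and `callback_steps` are gone from the signature but are popped from `**kwargs`, so existing call sites keep running and only emit a `deprecate` warning until 1.0.0. A sketch of the two call styles, assuming `init_image` and `control_image` are prepared elsewhere:

# old style: still accepted via **kwargs, but warns that it is deprecated
pipe(
    "a bird on a branch",
    image=init_image,
    control_image=control_image,
    callback=lambda step, t, latents: None,
    callback_steps=1,
)

# new style: a richer hook that runs once per denoising step
pipe(
    "a bird on a branch",
    image=init_image,
    control_image=control_image,
    callback_on_step_end=lambda pipeline, step, t, cb_kwargs: cb_kwargs,
)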
@@ -286,6 +286,7 @@ class StableDiffusionControlNetInpaintPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
         self,
@@ -656,18 +657,24 @@ class StableDiffusionControlNetInpaintPipeline(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -999,6 +1006,29 @@ class StableDiffusionControlNetInpaintPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1021,14 +1051,15 @@ class StableDiffusionControlNetInpaintPipeline(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -1101,12 +1132,6 @@ class StableDiffusionControlNetInpaintPipeline(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -1124,6 +1149,15 @@ class StableDiffusionControlNetInpaintPipeline(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.

         Examples:
@@ -1134,6 +1168,23 @@ class StableDiffusionControlNetInpaintPipeline(
             second element is a list of `bool`s indicating whether the corresponding generated image contains
             "not-safe-for-work" (nsfw) content.
         """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

         # align format for control guidance
@@ -1161,8 +1212,13 @@ class StableDiffusionControlNetInpaintPipeline(
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -1172,10 +1228,6 @@ class StableDiffusionControlNetInpaintPipeline(
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1189,23 +1241,23 @@ class StableDiffusionControlNetInpaintPipeline(
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         # 4. Prepare image
@@ -1218,7 +1270,7 @@ class StableDiffusionControlNetInpaintPipeline(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
         elif isinstance(controlnet, MultiControlNetModel):
@@ -1233,7 +1285,7 @@ class StableDiffusionControlNetInpaintPipeline(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
@@ -1261,6 +1313,7 @@ class StableDiffusionControlNetInpaintPipeline(
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
         is_strength_max = strength == 1.0
+        self._num_timesteps = len(timesteps)

         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
@@ -1297,7 +1350,7 @@ class StableDiffusionControlNetInpaintPipeline(
             prompt_embeds.dtype,
             device,
             generator,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
         )

         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -1317,11 +1370,11 @@ class StableDiffusionControlNetInpaintPipeline(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1348,7 +1401,7 @@ class StableDiffusionControlNetInpaintPipeline(
                     return_dict=False,
                 )

-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Inferred ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1363,14 +1416,14 @@ class StableDiffusionControlNetInpaintPipeline(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     return_dict=False,
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
@@ -1379,7 +1432,7 @@ class StableDiffusionControlNetInpaintPipeline(
                 if num_channels_unet == 4:
                     init_latents_proper = image_latents
-                    if do_classifier_free_guidance:
+                    if self.do_classifier_free_guidance:
                         init_mask, _ = mask.chunk(2)
                     else:
                         init_mask = mask
@@ -1392,6 +1445,16 @@ class StableDiffusionControlNetInpaintPipeline(
                     latents = (1 - init_mask) * init_latents_proper + init_mask * latents

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
...
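A side effect of the new `check_inputs` branch repeated in each pipeline: tensor names outside `_callback_tensor_inputs` are rejected before any denoising starts. A sketch against the inpaint pipeline, with `init_image`, `mask_image`, and `control_image` assumed prepared elsewhere:

# only "latents", "prompt_embeds", and "negative_prompt_embeds" are valid here;
# an unknown name fails fast with a ValueError from check_inputs
pipe(
    "a bird on a branch",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    callback_on_step_end_tensor_inputs=["noise_pred"],  # raises ValueError
)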
@@ -34,6 +34,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    deprecate,
     is_invisible_watermark_available,
     logging,
     replace_example_docstring,
@@ -167,6 +168,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
     model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
     _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
         self,
@@ -555,6 +557,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
@@ -565,14 +568,20 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
                 f" {type(num_inference_steps)}."
             )

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -1008,6 +1017,29 @@ class StableDiffusionXLControlNetInpaintPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1039,8 +1071,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         guess_mode: bool = False,
@@ -1053,6 +1083,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1147,12 +1180,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function will be called. If not specified, the callback will be
-                called at every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
@@ -1182,6 +1209,15 @@ class StableDiffusionXLControlNetInpaintPipeline(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.

         Examples:
...@@ -1190,6 +1226,23 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1190,6 +1226,23 @@ class StableDiffusionXLControlNetInpaintPipeline(
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images. `tuple`. When returning a tuple, the first element is a list with the generated images.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance # align format for control guidance
...@@ -1237,8 +1290,13 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1237,8 +1290,13 @@ class StableDiffusionXLControlNetInpaintPipeline(
controlnet_conditioning_scale, controlnet_conditioning_scale,
control_guidance_start, control_guidance_start,
control_guidance_end, control_guidance_end,
callback_on_step_end_tensor_inputs,
) )
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
...@@ -1248,17 +1306,13 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1248,17 +1306,13 @@ class StableDiffusionXLControlNetInpaintPipeline(
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
# 3. Encode input prompt # 3. Encode input prompt
text_encoder_lora_scale = ( text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
) )
( (
...@@ -1271,7 +1325,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1271,7 +1325,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt_2=prompt_2, prompt_2=prompt_2,
device=device, device=device,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2, negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
...@@ -1279,7 +1333,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1279,7 +1333,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale, lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# 4. set timesteps # 4. set timesteps
...@@ -1300,6 +1354,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1300,6 +1354,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0 is_strength_max = strength == 1.0
self._num_timesteps = len(timesteps)
# 5. Preprocess mask and image - resizes image and mask w.r.t height and width # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
# 5.1 Prepare init image # 5.1 Prepare init image
...@@ -1316,7 +1371,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1316,7 +1371,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=controlnet.dtype, dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode, guess_mode=guess_mode,
) )
elif isinstance(controlnet, MultiControlNetModel): elif isinstance(controlnet, MultiControlNetModel):
...@@ -1331,7 +1386,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1331,7 +1386,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=controlnet.dtype, dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode, guess_mode=guess_mode,
) )
...@@ -1385,7 +1440,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1385,7 +1440,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt_embeds.dtype, prompt_embeds.dtype,
device, device,
generator, generator,
do_classifier_free_guidance, self.do_classifier_free_guidance,
) )
# 8. Check that sizes of mask, masked image and latents match # 8. Check that sizes of mask, masked image and latents match
...@@ -1446,7 +1501,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1446,7 +1501,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
) )
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
...@@ -1483,7 +1538,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1483,7 +1538,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension # concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
...@@ -1491,7 +1546,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1491,7 +1546,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# controlnet(s) inference # controlnet(s) inference
if guess_mode and do_classifier_free_guidance: if guess_mode and self.do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch. # Infer ControlNet only for the conditional batch.
control_model_input = latents control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t) control_model_input = self.scheduler.scale_model_input(control_model_input, t)
...@@ -1528,7 +1583,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1528,7 +1583,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
return_dict=False, return_dict=False,
) )
if guess_mode and do_classifier_free_guidance: if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch. # Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches, # To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged. # add 0 to the unconditional batch to keep it unchanged.
...@@ -1543,7 +1598,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1543,7 +1598,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples, down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample, mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
...@@ -1551,11 +1606,11 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1551,11 +1606,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0: if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
...@@ -1564,7 +1619,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1564,7 +1619,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
if num_channels_unet == 4: if num_channels_unet == 4:
init_latents_proper = image_latents init_latents_proper = image_latents
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
init_mask, _ = mask.chunk(2) init_mask, _ = mask.chunk(2)
else: else:
init_mask = mask init_mask = mask
...@@ -1577,6 +1632,16 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1577,6 +1632,16 @@ class StableDiffusionXLControlNetInpaintPipeline(
latents = (1 - init_mask) * init_latents_proper + init_mask * latents latents = (1 - init_mask) * init_latents_proper + init_mask * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
......
...@@ -35,7 +35,14 @@ from ...models.attention_processor import ( ...@@ -35,7 +35,14 @@ from ...models.attention_processor import (
) )
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
...@@ -143,6 +150,7 @@ class StableDiffusionXLControlNetPipeline( ...@@ -143,6 +150,7 @@ class StableDiffusionXLControlNetPipeline(
# leave controlnet out on purpose because it iterates with unet # leave controlnet out on purpose because it iterates with unet
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__( def __init__(
self, self,
...@@ -487,15 +495,21 @@ class StableDiffusionXLControlNetPipeline( ...@@ -487,15 +495,21 @@ class StableDiffusionXLControlNetPipeline(
controlnet_conditioning_scale=1.0, controlnet_conditioning_scale=1.0,
control_guidance_start=0.0, control_guidance_start=0.0,
control_guidance_end=1.0, control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
): ):
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
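As a quick illustration of the validation above, requesting a key outside `_callback_tensor_inputs` fails before any denoising work starts. A hedged sketch, where `pipe` and `canny_image` are assumed to exist:

```python
# Hedged sketch: an unsupported tensor-input name is rejected by check_inputs.
try:
    pipe(
        prompt="a photo of a lion",
        image=canny_image,  # assumed conditioning image
        callback_on_step_end=lambda p, i, t, kwargs: kwargs,
        callback_on_step_end_tensor_inputs=["noise_pred"],  # not a supported key
    )
except ValueError as err:
    # e.g. "`callback_on_step_end_tensor_inputs` has to be in
    # ['latents', 'prompt_embeds', 'negative_prompt_embeds'], but found ['noise_pred']"
    print(err)
```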
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -825,6 +839,10 @@ class StableDiffusionXLControlNetPipeline( ...@@ -825,6 +839,10 @@ class StableDiffusionXLControlNetPipeline(
def guidance_scale(self): def guidance_scale(self):
return self._guidance_scale return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
...@@ -832,6 +850,14 @@ class StableDiffusionXLControlNetPipeline( ...@@ -832,6 +850,14 @@ class StableDiffusionXLControlNetPipeline(
def do_classifier_free_guidance(self): def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
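These properties exist mainly so that a step-end callback can read pipeline state mid-run. A read-only progress callback, sketched under the assumption of an already-loaded pipeline, might look like this:

```python
# Read-only use of the new properties inside a step-end callback.
def report_progress(pipeline, step, timestep, callback_kwargs):
    done = (step + 1) / pipeline.num_timesteps
    print(
        f"{done:.0%} done at t={timestep}; "
        f"guidance_scale={pipeline.guidance_scale}, clip_skip={pipeline.clip_skip}"
    )
    return callback_kwargs
```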
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -855,8 +881,6 @@ class StableDiffusionXLControlNetPipeline( ...@@ -855,8 +881,6 @@ class StableDiffusionXLControlNetPipeline(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0, controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False, guess_mode: bool = False,
...@@ -869,6 +893,9 @@ class StableDiffusionXLControlNetPipeline( ...@@ -869,6 +893,9 @@ class StableDiffusionXLControlNetPipeline(
negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None, negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -937,12 +964,6 @@ class StableDiffusionXLControlNetPipeline( ...@@ -937,12 +964,6 @@ class StableDiffusionXLControlNetPipeline(
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
...@@ -989,6 +1010,15 @@ class StableDiffusionXLControlNetPipeline( ...@@ -989,6 +1010,15 @@ class StableDiffusionXLControlNetPipeline(
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include all tensors specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You can only include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -997,6 +1027,23 @@ class StableDiffusionXLControlNetPipeline( ...@@ -997,6 +1027,23 @@ class StableDiffusionXLControlNetPipeline(
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned containing the output images. otherwise a `tuple` is returned containing the output images.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance # align format for control guidance
...@@ -1026,9 +1073,12 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1026,9 +1073,12 @@ class StableDiffusionXLControlNetPipeline(
controlnet_conditioning_scale, controlnet_conditioning_scale,
control_guidance_start, control_guidance_start,
control_guidance_end, control_guidance_end,
callback_on_step_end_tensor_inputs,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -1052,7 +1102,7 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1052,7 +1102,7 @@ class StableDiffusionXLControlNetPipeline(
# 3. Encode input prompt # 3. Encode input prompt
text_encoder_lora_scale = ( text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
) )
( (
prompt_embeds, prompt_embeds,
...@@ -1072,7 +1122,7 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1072,7 +1122,7 @@ class StableDiffusionXLControlNetPipeline(
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale, lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# 4. Prepare image # 4. Prepare image
...@@ -1115,6 +1165,7 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1115,6 +1165,7 @@ class StableDiffusionXLControlNetPipeline(
# 5. Prepare timesteps # 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables # 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels num_channels_latents = self.unet.config.in_channels
...@@ -1254,7 +1305,7 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1254,7 +1305,7 @@ class StableDiffusionXLControlNetPipeline(
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond, timestep_cond=timestep_cond,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples, down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample, mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
...@@ -1269,6 +1320,16 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1269,6 +1320,16 @@ class StableDiffusionXLControlNetPipeline(
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
......
...@@ -37,6 +37,7 @@ from ...models.lora import adjust_lora_scale_text_encoder ...@@ -37,6 +37,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
deprecate,
logging, logging,
replace_example_docstring, replace_example_docstring,
scale_lora_layers, scale_lora_layers,
...@@ -195,6 +196,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -195,6 +196,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__( def __init__(
self, self,
...@@ -543,6 +545,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -543,6 +545,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
controlnet_conditioning_scale=1.0, controlnet_conditioning_scale=1.0,
control_guidance_start=0.0, control_guidance_start=0.0,
control_guidance_end=1.0, control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
...@@ -553,14 +556,20 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -553,14 +556,20 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
f" {type(num_inference_steps)}." f" {type(num_inference_steps)}."
) )
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -951,6 +960,29 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -951,6 +960,29 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
"""Disables the FreeU mechanism if enabled.""" """Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu() self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
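Because `prompt_embeds` and `negative_prompt_embeds` are registered in `_callback_tensor_inputs`, a callback may also rewrite the embeddings between steps, as long as shapes are preserved. A deliberately simple, shape-preserving sketch; the perturbation itself is illustrative, not a recommended technique:

```python
import torch

# Illustrative only: perturb the prompt embeddings over the last 20% of steps.
# Shapes must be preserved, since the tensor is fed straight back to the UNet.
def jitter_prompt_embeds(pipeline, step, timestep, callback_kwargs):
    if step > int(0.8 * pipeline.num_timesteps):
        embeds = callback_kwargs["prompt_embeds"]
        callback_kwargs["prompt_embeds"] = embeds + 0.01 * torch.randn_like(embeds)
    return callback_kwargs

# Used with: pipe(..., callback_on_step_end=jitter_prompt_embeds,
#                 callback_on_step_end_tensor_inputs=["prompt_embeds"])
```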
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -976,8 +1008,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -976,8 +1008,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.8, controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
guess_mode: bool = False, guess_mode: bool = False,
...@@ -992,6 +1022,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -992,6 +1022,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
aesthetic_score: float = 6.0, aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5, negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -1077,12 +1110,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1077,12 +1110,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in `self.processor` in
...@@ -1138,6 +1165,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1138,6 +1165,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include all tensors specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You can only include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -1146,6 +1182,23 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1146,6 +1182,23 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
containing the output images. containing the output images.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance # align format for control guidance
...@@ -1177,8 +1230,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1177,8 +1230,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
controlnet_conditioning_scale, controlnet_conditioning_scale,
control_guidance_start, control_guidance_start,
control_guidance_end, control_guidance_end,
callback_on_step_end_tensor_inputs,
) )
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
...@@ -1188,10 +1246,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1188,10 +1246,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
...@@ -1205,7 +1259,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1205,7 +1259,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
# 3. Encode input prompt # 3. Encode input prompt
text_encoder_lora_scale = ( text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
) )
( (
prompt_embeds, prompt_embeds,
...@@ -1217,7 +1271,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1217,7 +1271,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
prompt_2, prompt_2,
device, device,
num_images_per_prompt, num_images_per_prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
negative_prompt, negative_prompt,
negative_prompt_2, negative_prompt_2,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
...@@ -1225,7 +1279,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1225,7 +1279,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale, lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# 4. Prepare image and controlnet_conditioning_image # 4. Prepare image and controlnet_conditioning_image
...@@ -1240,7 +1294,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1240,7 +1294,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=controlnet.dtype, dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode, guess_mode=guess_mode,
) )
height, width = control_image.shape[-2:] height, width = control_image.shape[-2:]
...@@ -1256,7 +1310,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1256,7 +1310,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
dtype=controlnet.dtype, dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode, guess_mode=guess_mode,
) )
...@@ -1271,6 +1325,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1271,6 +1325,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables # 6. Prepare latent variables
latents = self.prepare_latents( latents = self.prepare_latents(
...@@ -1328,7 +1383,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1328,7 +1383,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
) )
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
...@@ -1343,13 +1398,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1343,13 +1398,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# controlnet(s) inference # controlnet(s) inference
if guess_mode and do_classifier_free_guidance: if guess_mode and self.do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch. # Infer ControlNet only for the conditional batch.
control_model_input = latents control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t) control_model_input = self.scheduler.scale_model_input(control_model_input, t)
...@@ -1382,7 +1437,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1382,7 +1437,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
return_dict=False, return_dict=False,
) )
if guess_mode and do_classifier_free_guidance: if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch. # Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches, # To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged. # add 0 to the unconditional batch to keep it unchanged.
...@@ -1394,7 +1449,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1394,7 +1449,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples, down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample, mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
...@@ -1402,13 +1457,23 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1402,13 +1457,23 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
......