Unverified Commit 2b23ec82 authored by YiYi Xu, committed by GitHub

add callbacks to denoising step (#5427)



* draft1

* update

* style

* move to the end of loop

* update

* update callback_on_step_end_inputs

* Revert "update"

This reverts commit 5f9b153183d0cde3b850f14024d2e37ae8c19576.

* Revert "update callback_on_step_end_inputs"

This reverts commit 44889f4dabad95b7ebb330faa5f1955b5d008c88.

* update

* update test required_optional_params

* remove self.lora_scale

* img2img

* inpaint

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix

* apply feedback on img2img + inpaint: keep only important pipeline attributes

* depth

* pix2pix

* make _callback_tensor_inputs a class variable so that we can use it for testing

* add a basic test for callback

* add a read-only tensor input timesteps + fix tests

* add second test for callback cfg

* sdxl

* sdxl img2img

* sdxl inpaint

* kandinsky prior

* kandinsky decoder

* kandinsky img2img + combined

* kandinsky inpaint

* fix copies

* fix

* consistent default inputs

* fix copies

* wuerstchen_prior prior

* test_wuerstchen_decoder + fix test for prior

* wuerstchen_combined pipeline + skip tests

* skip test for kandinsky combined

* lcm

* remove timesteps etc

* add doc string

* copies

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* make style and improve tests

* up

* up

* fix more

* fix cfg test

* tests for callbacks

* fix for real

* update

* lcm img2img

* add doc

* add doc page to index

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 080081bd
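For reference, the `callback_on_step_end` API this PR adds can be driven as in the sketch below. This is a minimal illustration rather than code from the diff: the checkpoint id, the 40% cutoff, and the direct write to the private `_guidance_scale` attribute (which backs the new `guidance_scale` property) are all assumptions.

```python
import torch
from diffusers import StableDiffusionPipeline

# Assumed checkpoint id; any Stable Diffusion 1.x checkpoint works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")


def cutoff_cfg(pipe, step, timestep, callback_kwargs):
    # After 40% of the steps, keep only the text-conditioned half of the
    # prompt embeddings and disable classifier-free guidance for the rest
    # of the denoising loop.
    if step == int(0.4 * pipe.num_timesteps):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        callback_kwargs["prompt_embeds"] = prompt_embeds.chunk(2)[-1]
        pipe._guidance_scale = 0.0
    return callback_kwargs


image = pipe(
    "a photo of an astronaut riding a horse on mars",
    callback_on_step_end=cutoff_cfg,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
```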
@@ -13,7 +13,7 @@
 # limitations under the License.

 import inspect
-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union

 import numpy as np
 import PIL.Image
@@ -92,6 +92,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"]

     def __init__(
         self,
@@ -152,8 +153,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -201,12 +203,15 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. It is called with the
+                following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include all tensors specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You can only include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.

         Examples:
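To make the signature documented above concrete, a minimal read-only callback could look like the following; the logging body is illustrative, and returning the dict unchanged leaves the denoising loop unaffected.

```python
def log_latents(pipe, step, timestep, callback_kwargs):
    # Inspect, but do not modify, the current latents.
    latents = callback_kwargs["latents"]
    print(f"step {step:3d} (t={int(timestep)}): latents norm = {latents.norm().item():.2f}")
    return callback_kwargs


# Passed at call time (`pipe` and `init_image` assumed to exist):
# pipe("make it snowy", image=init_image, callback_on_step_end=log_latents,
#      callback_on_step_end_tensor_inputs=["latents"])
```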
@@ -244,8 +249,34 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         # 0. Check inputs
-        self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+        self.check_inputs(
+            prompt,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+        )
+        self._guidance_scale = guidance_scale
+        self._image_guidance_scale = image_guidance_scale
+
         if image is None:
             raise ValueError("`image` input cannot be undefined.")
@@ -259,10 +290,6 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0

         # check if scheduler is in sigmas space
         scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
@@ -271,7 +298,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
@@ -291,7 +318,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             num_images_per_prompt,
             prompt_embeds.dtype,
             device,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             generator,
         )
@@ -328,12 +355,13 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # Expand the latents if we are doing classifier free guidance.
                 # The latents are expanded 3 times because for pix2pix the guidance\
                 # is applied for both the text and the input image.
-                latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents

                 # concat latents, image_latents in the channel dimension
                 scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -354,12 +382,12 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
                     noise_pred = latent_model_input - sigma * noise_pred

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
                     noise_pred = (
                         noise_pred_uncond
-                        + guidance_scale * (noise_pred_text - noise_pred_image)
-                        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
+                        + self.guidance_scale * (noise_pred_text - noise_pred_image)
+                        + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                     )

                 # Hack:
@@ -374,6 +402,17 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    image_latents = callback_outputs.pop("image_latents", image_latents)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
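The dispatch just added is worth spelling out: the loop collects the requested loop-local variables by name via `locals()`, hands them to the callback, and reads back only the names it knows how to reassign. A standalone toy version of the pattern (not diffusers code; floats stand in for tensors):

```python
def run_loop(callback_on_step_end=None, callback_on_step_end_tensor_inputs=("latents",)):
    latents, prompt_embeds = 1.0, 2.0  # stand-ins for real tensors
    for i, t in enumerate([900, 600, 300]):
        latents = latents * 0.5  # stand-in for the scheduler step
        if callback_on_step_end is not None:
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                callback_kwargs[k] = locals()[k]  # look up loop locals by name
            callback_outputs = callback_on_step_end(None, i, t, callback_kwargs)
            # Only names the loop knows about are read back; anything else
            # the callback returns is silently ignored.
            latents = callback_outputs.pop("latents", latents)
            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
    return latents


print(run_loop(lambda p, i, t, kw: {**kw, "latents": 0.0}))  # -> 0.0
```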
@@ -596,16 +635,27 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         return image

     def check_inputs(
-        self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
-    ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        self,
+        prompt,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -728,3 +778,22 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
     def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def image_guidance_scale(self):
+        return self._image_guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0
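These read-only properties let a callback observe loop state without threading extra arguments through `__call__`. A hypothetical probe, for illustration only:

```python
def guidance_probe(pipe, step, timestep, callback_kwargs):
    # On the last step, report the guidance configuration the loop ran with.
    if step == pipe.num_timesteps - 1:
        print(
            f"ran {pipe.num_timesteps} steps, cfg={pipe.do_classifier_free_guidance}, "
            f"guidance_scale={pipe.guidance_scale}, "
            f"image_guidance_scale={pipe.image_guidance_scale}"
        )
    return callback_kwargs
```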
@@ -382,17 +382,22 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
...
@@ -452,17 +452,22 @@ class StableDiffusionLDM3DPipeline(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
...
@@ -433,17 +433,22 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
...
@@ -430,17 +430,22 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
...
@@ -436,17 +436,22 @@ class StableDiffusionParadigmsPipeline(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
...
@@ -440,17 +440,22 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
...
@@ -364,17 +364,22 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
...
@@ -35,6 +35,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    deprecate,
     is_invisible_watermark_available,
     is_torch_xla_available,
     logging,
@@ -141,6 +142,15 @@ class StableDiffusionXLPipeline(
     """

     model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
     _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "negative_pooled_prompt_embeds",
+        "negative_add_time_ids",
+    ]

     def __init__(
         self,
@@ -476,18 +486,24 @@ class StableDiffusionXLPipeline(
         negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -620,6 +636,37 @@ class StableDiffusionXLPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -643,8 +690,6 @@ class StableDiffusionXLPipeline(
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
         original_size: Optional[Tuple[int, int]] = None,
@@ -654,6 +699,9 @@ class StableDiffusionXLPipeline(
         negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
         negative_target_size: Optional[Tuple[int, int]] = None,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -730,12 +778,6 @@ class StableDiffusionXLPipeline(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
-            callback (`Callable`, *optional*):
-                A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function will be called. If not specified, the callback will be
-                called at every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
@@ -774,6 +816,15 @@ class StableDiffusionXLPipeline(
                 as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                 information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. It is called with the
+                following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include all tensors specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You can only include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.

         Examples:
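SDXL threads extra conditioning tensors through the loop, all of which can be requested by name. A minimal sketch, assuming the checkpoint id below (any SDXL base checkpoint should behave the same); the shape-printing callback is purely illustrative:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")


def inspect_sdxl_conditioning(pipe, step, timestep, callback_kwargs):
    # Every requested name is in _callback_tensor_inputs, so this validates
    # and then prints the tensor shapes once, on the first step.
    if step == 0:
        for name, tensor in callback_kwargs.items():
            print(f"{name}: {tuple(tensor.shape)}")
    return callback_kwargs


image = pipe(
    "an astronaut riding a horse",
    callback_on_step_end=inspect_sdxl_conditioning,
    callback_on_step_end_tensor_inputs=["latents", "prompt_embeds", "add_text_embeds", "add_time_ids"],
).images[0]
```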
@@ -782,6 +833,23 @@ class StableDiffusionXLPipeline(
             [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
             `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         # 0. Default height and width to unet
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
@@ -802,8 +870,15 @@ class StableDiffusionXLPipeline(
             negative_prompt_embeds,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -814,13 +889,10 @@ class StableDiffusionXLPipeline(
         device = self._execution_device

-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
-        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )

         (
             prompt_embeds,
@@ -832,7 +904,7 @@ class StableDiffusionXLPipeline(
             prompt_2=prompt_2,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
             negative_prompt=negative_prompt,
             negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
@@ -840,7 +912,7 @@ class StableDiffusionXLPipeline(
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             lora_scale=lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )

         # 4. Prepare timesteps
@@ -889,7 +961,7 @@ class StableDiffusionXLPipeline(
         else:
             negative_add_time_ids = add_time_ids

-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -902,20 +974,26 @@ class StableDiffusionXLPipeline(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

         # 8.1 Apply denoising_end
-        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+        if (
+            self.denoising_end is not None
+            and isinstance(self.denoising_end, float)
+            and self.denoising_end > 0
+            and self.denoising_end < 1
+        ):
             discrete_timestep_cutoff = int(
                 round(
                     self.scheduler.config.num_train_timesteps
-                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                 )
             )
             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
             timesteps = timesteps[:num_inference_steps]

+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -925,23 +1003,39 @@ class StableDiffusionXLPipeline(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

-                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
...
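Because the SDXL loop reads back every tensor in `_callback_tensor_inputs`, a callback can also be used purely as a tap. A small sketch (the five-step stride and CPU copies are arbitrary choices, not from this diff):

```python
intermediate_latents = []


def capture_latents(pipe, step, timestep, callback_kwargs):
    # Keep a detached CPU copy of the latents every 5 steps; they can be
    # decoded with pipe.vae after the run to visualize convergence.
    if step % 5 == 0:
        intermediate_latents.append(callback_kwargs["latents"].detach().cpu())
    return callback_kwargs
```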
@@ -32,6 +32,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    deprecate,
     is_invisible_watermark_available,
     is_torch_xla_available,
     logging,
@@ -154,6 +155,15 @@ class StableDiffusionXLImg2ImgPipeline(
     """

     model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
     _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "negative_pooled_prompt_embeds",
+        "add_neg_time_ids",
+    ]

     def __init__(
         self,
@@ -488,6 +498,7 @@ class StableDiffusionXLImg2ImgPipeline(
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -498,14 +509,19 @@ class StableDiffusionXLImg2ImgPipeline(
                 f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
                 f" {type(num_inference_steps)}."
             )
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -747,6 +763,41 @@ class StableDiffusionXLImg2ImgPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
+    @property
+    def denoising_start(self):
+        return self._denoising_start
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -771,8 +822,6 @@ class StableDiffusionXLImg2ImgPipeline(
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
         original_size: Tuple[int, int] = None,
@@ -784,6 +833,9 @@ class StableDiffusionXLImg2ImgPipeline(
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -867,12 +919,6 @@ class StableDiffusionXLImg2ImgPipeline(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function will be called. If not specified, the callback will be
-                called at every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
@@ -922,6 +968,15 @@ class StableDiffusionXLImg2ImgPipeline(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. It is called with the
+                following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include all tensors specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You can only include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.

         Examples:
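One naming detail worth flagging: the img2img variant exposes `add_neg_time_ids` where the text-to-image pipeline exposes `negative_add_time_ids`, so the names requested via `callback_on_step_end_tensor_inputs` must match the class being called. This can be checked directly:

```python
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline

# Class attributes, so no weights need to be loaded to inspect them.
print(StableDiffusionXLPipeline._callback_tensor_inputs)
print(StableDiffusionXLImg2ImgPipeline._callback_tensor_inputs)
```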
@@ -930,6 +985,23 @@ class StableDiffusionXLImg2ImgPipeline(
             [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
             `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
@@ -941,8 +1013,16 @@ class StableDiffusionXLImg2ImgPipeline(
             negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+        self._denoising_start = denoising_start
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -953,14 +1033,9 @@ class StableDiffusionXLImg2ImgPipeline(
         device = self._execution_device

-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         (
             prompt_embeds,
@@ -972,7 +1047,7 @@ class StableDiffusionXLImg2ImgPipeline(
             prompt_2=prompt_2,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
             negative_prompt=negative_prompt,
             negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
@@ -980,7 +1055,7 @@ class StableDiffusionXLImg2ImgPipeline(
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )

         # 4. Preprocess image
@@ -988,15 +1063,18 @@ class StableDiffusionXLImg2ImgPipeline(
         # 5. Prepare timesteps
         def denoising_value_valid(dnv):
-            return isinstance(denoising_end, float) and 0 < dnv < 1
+            return isinstance(self.denoising_end, float) and 0 < dnv < 1

         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(
-            num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid else None
+            num_inference_steps,
+            strength,
+            device,
+            denoising_start=self.denoising_start if denoising_value_valid else None,
         )
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

-        add_noise = True if denoising_start is None else False
+        add_noise = True if self.denoising_start is None else False
         # 6. Prepare latent variables
         latents = self.prepare_latents(
             image,
@@ -1044,7 +1122,7 @@ class StableDiffusionXLImg2ImgPipeline(
         )
         add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
...@@ -1059,30 +1137,30 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -1059,30 +1137,30 @@ class StableDiffusionXLImg2ImgPipeline(
# 9.1 Apply denoising_end # 9.1 Apply denoising_end
if ( if (
denoising_end is not None self.denoising_end is not None
and denoising_start is not None and self.denoising_start is not None
and denoising_value_valid(denoising_end) and denoising_value_valid(self.denoising_end)
and denoising_value_valid(denoising_start) and denoising_value_valid(self.denoising_start)
and denoising_start >= denoising_end and self.denoising_start >= self.denoising_end
): ):
raise ValueError( raise ValueError(
f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ f" {denoising_end} when using type float." + f" {self.denoising_end} when using type float."
) )
elif denoising_end is not None and denoising_value_valid(denoising_end): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
discrete_timestep_cutoff = int( discrete_timestep_cutoff = int(
round( round(
self.scheduler.config.num_train_timesteps self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps) - (self.denoising_end * self.scheduler.config.num_train_timesteps)
) )
) )
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps] timesteps = timesteps[:num_inference_steps]
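As a sanity check on the cutoff arithmetic above, a small worked example with illustrative values:

```python
# With a scheduler trained on 1000 timesteps and denoising_end=0.8, only
# timesteps >= 200 are kept, i.e. the first 80% of the denoising
# trajectory; a refiner can then resume with denoising_start=0.8.
num_train_timesteps = 1000  # stands in for self.scheduler.config.num_train_timesteps
denoising_end = 0.8
discrete_timestep_cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
assert discrete_timestep_cutoff == 200
```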
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
...@@ -1092,23 +1170,39 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -1092,23 +1170,39 @@ class StableDiffusionXLImg2ImgPipeline(
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0: if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
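The `locals()[k]` lookup above hands each requested tensor to the callback, and whatever dict the callback returns is popped back into the loop's local variables. A minimal user-side sketch under these semantics (`pipe` and `init_image` are assumed to be a loaded SDXL img2img pipeline and a PIL image):

```python
latent_norms = []  # collected across steps for inspection

def log_latents(pipeline, step, timestep, callback_kwargs):
    # `callback_kwargs` holds exactly the tensors requested via
    # `callback_on_step_end_tensor_inputs`; the returned dict entries
    # replace the loop's locals of the same name (unchanged here).
    latent_norms.append(callback_kwargs["latents"].float().norm().item())
    return callback_kwargs

image = pipe(
    prompt="a photo of an astronaut",
    image=init_image,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```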
......
...@@ -301,6 +301,17 @@ class StableDiffusionXLInpaintPipeline( ...@@ -301,6 +301,17 @@ class StableDiffusionXLInpaintPipeline(
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
"mask",
"masked_image_latents",
]
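This class attribute is the whitelist that `check_inputs` (below) validates `callback_on_step_end_tensor_inputs` against. A hedged sketch of a call that requests the inpainting `mask` alongside the `latents` (`pipe`, `init_image`, `mask_image`, and `my_callback` are assumed to exist):

```python
out = pipe(
    prompt="a castle on a hill",
    image=init_image,
    mask_image=mask_image,
    callback_on_step_end=my_callback,
    # Valid: both names appear in _callback_tensor_inputs; an unknown
    # name such as "noise_pred" would raise a ValueError in check_inputs().
    callback_on_step_end_tensor_inputs=["latents", "mask"],
)
```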
def __init__( def __init__(
self, self,
...@@ -639,6 +650,7 @@ class StableDiffusionXLInpaintPipeline( ...@@ -639,6 +650,7 @@ class StableDiffusionXLInpaintPipeline(
negative_prompt_2=None, negative_prompt_2=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
...@@ -646,14 +658,19 @@ class StableDiffusionXLInpaintPipeline( ...@@ -646,14 +658,19 @@ class StableDiffusionXLInpaintPipeline(
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -965,6 +982,41 @@ class StableDiffusionXLInpaintPipeline( ...@@ -965,6 +982,41 @@ class StableDiffusionXLInpaintPipeline(
"""Disables the FreeU mechanism if enabled.""" """Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu() self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def denoising_end(self):
return self._denoising_end
@property
def denoising_start(self):
return self._denoising_start
@property
def num_timesteps(self):
return self._num_timesteps
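These read-only properties mirror the call arguments that `__call__` now stashes on the pipeline (`self._guidance_scale` and friends), so step-end callbacks can inspect the run without extra plumbing. A minimal sketch of a callback that only reads them:

```python
def report_config(pipeline, step, timestep, callback_kwargs):
    # The properties give callbacks a read-only view of the call's
    # configuration without threading extra arguments through the loop.
    if step == 0:
        print(
            f"steps={pipeline.num_timesteps} "
            f"cfg={pipeline.do_classifier_free_guidance} "
            f"scale={pipeline.guidance_scale} rescale={pipeline.guidance_rescale}"
        )
    return callback_kwargs
```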
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -993,8 +1045,6 @@ class StableDiffusionXLInpaintPipeline( ...@@ -993,8 +1045,6 @@ class StableDiffusionXLInpaintPipeline(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0, guidance_rescale: float = 0.0,
original_size: Tuple[int, int] = None, original_size: Tuple[int, int] = None,
...@@ -1006,6 +1056,9 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1006,6 +1056,9 @@ class StableDiffusionXLInpaintPipeline(
aesthetic_score: float = 6.0, aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5, negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -1106,12 +1159,6 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1106,12 +1159,6 @@ class StableDiffusionXLInpaintPipeline(
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in `self.processor` in
...@@ -1156,6 +1203,15 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1156,6 +1203,15 @@ class StableDiffusionXLInpaintPipeline(
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include all tensors specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You can only include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -1164,6 +1220,23 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1164,6 +1220,23 @@ class StableDiffusionXLInpaintPipeline(
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. `tuple`. When returning a tuple, the first element is a list with the generated images.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
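The shim above keeps old `callback`/`callback_steps` call sites working while warning that they will be removed in 1.0.0. A simplified stand-in for the behaviour (the real `deprecate` helper in `diffusers.utils` also handles version gating):

```python
import warnings

def deprecate(name, remove_in, message):  # simplified stand-in, not the real API
    warnings.warn(
        f"`{name}` is deprecated and will be removed in version {remove_in}. {message}",
        FutureWarning,
        stacklevel=2,
    )

# An old-style call site keeps working but now emits a FutureWarning:
kwargs = {"callback": lambda step, t, latents: None, "callback_steps": 1}
callback = kwargs.pop("callback", None)
if callback is not None:
    deprecate("callback", "1.0.0", "Consider using `callback_on_step_end` instead.")
```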
# 0. Default height and width to unet # 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor
...@@ -1180,8 +1253,16 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1180,8 +1253,16 @@ class StableDiffusionXLInpaintPipeline(
negative_prompt_2, negative_prompt_2,
prompt_embeds, prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
) )
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._denoising_start = denoising_start
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
...@@ -1191,14 +1272,10 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1191,14 +1272,10 @@ class StableDiffusionXLInpaintPipeline(
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_encoder_lora_scale = ( text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
) )
( (
...@@ -1211,7 +1288,7 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1211,7 +1288,7 @@ class StableDiffusionXLInpaintPipeline(
prompt_2=prompt_2, prompt_2=prompt_2,
device=device, device=device,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2, negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
...@@ -1219,16 +1296,19 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1219,16 +1296,19 @@ class StableDiffusionXLInpaintPipeline(
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale, lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# 4. set timesteps # 4. set timesteps
def denoising_value_valid(dnv): def denoising_value_valid(dnv):
return isinstance(denoising_end, float) and 0 < dnv < 1 return isinstance(dnv, float) and 0 < dnv < 1
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps( timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid else None num_inference_steps,
strength,
device,
denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
) )
# check that number of inference steps is not < 1 - as this doesn't make sense # check that number of inference steps is not < 1 - as this doesn't make sense
if num_inference_steps < 1: if num_inference_steps < 1:
...@@ -1260,7 +1340,7 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1260,7 +1340,7 @@ class StableDiffusionXLInpaintPipeline(
num_channels_unet = self.unet.config.in_channels num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4 return_image_latents = num_channels_unet == 4
add_noise = True if denoising_start is None else False add_noise = self.denoising_start is None
latents_outputs = self.prepare_latents( latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
...@@ -1293,7 +1373,7 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1293,7 +1373,7 @@ class StableDiffusionXLInpaintPipeline(
prompt_embeds.dtype, prompt_embeds.dtype,
device, device,
generator, generator,
do_classifier_free_guidance, self.do_classifier_free_guidance,
) )
# 8. Check that sizes of mask, masked image and latents match # 8. Check that sizes of mask, masked image and latents match
...@@ -1350,7 +1430,7 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1350,7 +1430,7 @@ class StableDiffusionXLInpaintPipeline(
) )
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
...@@ -1364,30 +1444,31 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1364,30 +1444,31 @@ class StableDiffusionXLInpaintPipeline(
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
if ( if (
denoising_end is not None self.denoising_end is not None
and denoising_start is not None and self.denoising_start is not None
and denoising_value_valid(denoising_end) and denoising_value_valid(self.denoising_end)
and denoising_value_valid(denoising_start) and denoising_value_valid(self.denoising_start)
and denoising_start >= denoising_end and self.denoising_start >= self.denoising_end
): ):
raise ValueError( raise ValueError(
f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ f" {denoising_end} when using type float." + f" {self.denoising_end} when using type float."
) )
elif denoising_end is not None and denoising_value_valid(denoising_end): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
discrete_timestep_cutoff = int( discrete_timestep_cutoff = int(
round( round(
self.scheduler.config.num_train_timesteps self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps) - (self.denoising_end * self.scheduler.config.num_train_timesteps)
) )
) )
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps] timesteps = timesteps[:num_inference_steps]
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension # concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
...@@ -1401,26 +1482,26 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1401,26 +1482,26 @@ class StableDiffusionXLInpaintPipeline(
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0: if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if num_channels_unet == 4: if num_channels_unet == 4:
init_latents_proper = image_latents init_latents_proper = image_latents
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
init_mask, _ = mask.chunk(2) init_mask, _ = mask.chunk(2)
else: else:
init_mask = mask init_mask = mask
...@@ -1433,6 +1514,24 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1433,6 +1514,24 @@ class StableDiffusionXLInpaintPipeline(
latents = (1 - init_mask) * init_latents_proper + init_mask * latents latents = (1 - init_mask) * init_latents_proper + init_mask * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
mask = callback_outputs.pop("mask", mask)
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
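Because `mask` and `masked_image_latents` are re-read from the loop's locals every iteration, a callback can reshape the inpainting schedule on the fly. A hypothetical sketch (the schedule and threshold are illustrative):

```python
def harden_mask(pipeline, step, timestep, callback_kwargs):
    # Hypothetical tweak: binarize a soft mask halfway through sampling,
    # so late steps blend original and generated latents with hard edges.
    if step == pipeline.num_timesteps // 2:
        mask = callback_kwargs["mask"]
        callback_kwargs["mask"] = (mask > 0.5).to(mask.dtype)
    return callback_kwargs

# "mask" must be requested explicitly, otherwise callback_kwargs omits it:
# pipe(..., callback_on_step_end=harden_mask,
#      callback_on_step_end_tensor_inputs=["latents", "mask"])
```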
......
...@@ -446,16 +446,27 @@ class StableDiffusionXLInstructPix2PixPipeline( ...@@ -446,16 +446,27 @@ class StableDiffusionXLInstructPix2PixPipeline(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
def check_inputs( def check_inputs(
self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self,
): prompt,
if (callback_steps is None) or ( callback_steps,
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
......
...@@ -491,18 +491,24 @@ class StableDiffusionXLAdapterPipeline( ...@@ -491,18 +491,24 @@ class StableDiffusionXLAdapterPipeline(
negative_prompt_embeds=None, negative_prompt_embeds=None,
pooled_prompt_embeds=None, pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
......
...@@ -416,17 +416,22 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora ...@@ -416,17 +416,22 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
......
...@@ -471,19 +471,30 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor ...@@ -471,19 +471,30 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self,
prompt,
strength,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
......
...@@ -255,17 +255,22 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -255,17 +255,22 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
......
...@@ -12,14 +12,14 @@ ...@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler from ...schedulers import DDPMWuerstchenScheduler
from ...utils import logging, replace_example_docstring from ...utils import deprecate, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel from .modeling_paella_vq_model import PaellaVQModel
...@@ -73,6 +73,12 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -73,6 +73,12 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
""" """
model_cpu_offload_seq = "text_encoder->decoder->vqgan" model_cpu_offload_seq = "text_encoder->decoder->vqgan"
_callback_tensor_inputs = [
"latents",
"text_encoder_hidden_states",
"negative_prompt_embeds",
"image_embeddings",
]
def __init__( def __init__(
self, self,
...@@ -187,6 +193,18 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -187,6 +193,18 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
# to avoid doing two forward passes # to avoid doing two forward passes
return text_encoder_hidden_states, uncond_text_encoder_hidden_states return text_encoder_hidden_states, uncond_text_encoder_hidden_states
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -202,8 +220,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -202,8 +220,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_steps: int = 1, callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -242,12 +261,15 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -242,12 +261,15 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*): callback_on_step_end (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be A function that is called at the end of each denoising step during inference. The function is called
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_steps (`int`, *optional*, defaults to 1): callback_kwargs: Dict)`. `callback_kwargs` will include all tensors specified by
The frequency at which the `callback` function will be called. If not specified, the callback will be `callback_on_step_end_tensor_inputs`.
called at every step. callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You can only include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -257,10 +279,33 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -257,10 +279,33 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
embeddings. embeddings.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# 0. Define commonly used variables # 0. Define commonly used variables
device = self._execution_device device = self._execution_device
dtype = self.decoder.dtype dtype = self.decoder.dtype
do_classifier_free_guidance = guidance_scale > 1.0 self._guidance_scale = guidance_scale
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
if not isinstance(prompt, list): if not isinstance(prompt, list):
...@@ -269,7 +314,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -269,7 +314,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
else: else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list): if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str): if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] negative_prompt = [negative_prompt]
...@@ -298,7 +343,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -298,7 +343,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
prompt, prompt,
device, device,
image_embeddings.size(0) * num_images_per_prompt, image_embeddings.size(0) * num_images_per_prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
negative_prompt, negative_prompt,
) )
text_encoder_hidden_states = ( text_encoder_hidden_states = (
...@@ -323,25 +368,26 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -323,25 +368,26 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler) latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop # 6. Run denoising loop
self._num_timesteps = len(timesteps[:-1])
for i, t in enumerate(self.progress_bar(timesteps[:-1])): for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype) ratio = t.expand(latents.size(0)).to(dtype)
effnet = ( effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
if do_classifier_free_guidance if self.do_classifier_free_guidance
else image_embeddings else image_embeddings
) )
# 7. Denoise latents # 7. Denoise latents
predicted_latents = self.decoder( predicted_latents = self.decoder(
torch.cat([latents] * 2) if do_classifier_free_guidance else latents, torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio,
effnet=effnet, effnet=effnet,
clip=text_encoder_hidden_states, clip=text_encoder_hidden_states,
) )
# 8. Check for classifier free guidance and apply it # 8. Check for classifier free guidance and apply it
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2)
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, guidance_scale) predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
# 9. Renoise latents to next timestep # 9. Renoise latents to next timestep
latents = self.scheduler.step( latents = self.scheduler.step(
...@@ -351,25 +397,41 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -351,25 +397,41 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
generator=generator, generator=generator,
).prev_sample ).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
image_embeddings = callback_outputs.pop("image_embeddings", image_embeddings)
text_encoder_hidden_states = callback_outputs.pop(
"text_encoder_hidden_states", text_encoder_hidden_states
)
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents) callback(step_idx, t, latents)
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
)
if not output_type == "latent":
# 10. Scale and decode the image latents with vq-vae # 10. Scale and decode the image latents with vq-vae
latents = self.vqgan.config.scale_factor * latents latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1) images = self.vqgan.decode(latents).sample.clamp(0, 1)
# Offload all models
self.maybe_free_model_hooks()
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}")
if output_type == "np": if output_type == "np":
images = images.permute(0, 2, 3, 1).cpu().numpy() images = images.permute(0, 2, 3, 1).cpu().numpy()
elif output_type == "pil": elif output_type == "pil":
images = images.permute(0, 2, 3, 1).cpu().numpy() images = images.permute(0, 2, 3, 1).cpu().numpy()
images = self.numpy_to_pil(images) images = self.numpy_to_pil(images)
else:
images = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return images return images
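The new `latent` output type returns the undecoded, unscaled latents, which can be decoded later by mirroring the pipeline's own VQGAN path. A hedged sketch:

```python
latents = decoder_pipe(
    image_embeddings=image_embeddings,  # assumed output of a prior run
    prompt="an image",
    output_type="latent",
    return_dict=False,
)
# Decode manually later, mirroring the decoding branch above:
scaled = decoder_pipe.vqgan.config.scale_factor * latents
images = decoder_pipe.vqgan.decode(scaled).sample.clamp(0, 1)
```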
......
...@@ -11,13 +11,13 @@ ...@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import torch import torch
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler from ...schedulers import DDPMWuerstchenScheduler
from ...utils import replace_example_docstring from ...utils import deprecate, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from .modeling_paella_vq_model import PaellaVQModel from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
...@@ -161,10 +161,11 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -161,10 +161,11 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
prior_callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
prior_callback_steps: int = 1, prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_steps: int = 1, callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -226,19 +227,23 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -226,19 +227,23 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
prior_callback (`Callable`, *optional*): prior_callback_on_step_end (`Callable`, *optional*):
A function that will be called every `prior_callback_steps` steps during inference. The function will A function that is called at the end of each denoising step during inference. The function is called
be called with the following arguments: `prior_callback(step: int, timestep: int, latents: with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
torch.FloatTensor)`. int, callback_kwargs: Dict)`.
prior_callback_steps (`int`, *optional*, defaults to 1): prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
The frequency at which the `callback` function will be called. If not specified, the callback will be The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
called at every step. list will be passed as the `callback_kwargs` argument. You can only include variables listed in
callback (`Callable`, *optional*): the `._callback_tensor_inputs` attribute of your pipeline class.
A function that will be called every `callback_steps` steps during inference. The function will be callback_on_step_end (`Callable`, *optional*):
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. A function that is called at the end of each denoising step during inference. The function is called
callback_steps (`int`, *optional*, defaults to 1): with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
The frequency at which the `callback` function will be called. If not specified, the callback will be callback_kwargs: Dict)`. `callback_kwargs` will include all tensors specified by
called at every step. `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You can only include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -246,6 +251,22 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -246,6 +251,22 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
""" """
prior_kwargs = {}
if kwargs.get("prior_callback", None) is not None:
prior_kwargs["callback"] = kwargs.pop("prior_callback")
deprecate(
"prior_callback",
"1.0.0",
"Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
)
if kwargs.get("prior_callback_steps", None) is not None:
deprecate(
"prior_callback_steps",
"1.0.0",
"Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
)
prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps")
prior_outputs = self.prior_pipe( prior_outputs = self.prior_pipe(
prompt=prompt if prompt_embeds is None else None, prompt=prompt if prompt_embeds is None else None,
height=height, height=height,
...@@ -261,8 +282,9 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -261,8 +282,9 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
latents=latents, latents=latents,
output_type="pt", output_type="pt",
return_dict=False, return_dict=False,
callback=prior_callback, callback_on_step_end=prior_callback_on_step_end,
callback_steps=prior_callback_steps, callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
**prior_kwargs,
) )
image_embeddings = prior_outputs[0] image_embeddings = prior_outputs[0]
...@@ -276,8 +298,9 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -276,8 +298,9 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
generator=generator, generator=generator,
output_type=output_type, output_type=output_type,
return_dict=return_dict, return_dict=return_dict,
callback=callback, callback_on_step_end=callback_on_step_end,
callback_steps=callback_steps, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
**kwargs,
) )
return outputs return outputs
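Since the combined pipeline simply forwards the two callback pairs to its sub-pipelines, both stages can be observed independently. A hedged sketch (`combined_pipe` is assumed to be a loaded `WuerstchenCombinedPipeline`):

```python
def prior_cb(pipeline, step, timestep, callback_kwargs):
    print("prior stage, step", step)    # runs in the text -> image-embedding stage
    return callback_kwargs

def decoder_cb(pipeline, step, timestep, callback_kwargs):
    print("decoder stage, step", step)  # runs in the embedding -> image stage
    return callback_kwargs

images = combined_pipe(
    prompt="an astronaut riding a horse",
    prior_callback_on_step_end=prior_cb,
    prior_callback_on_step_end_tensor_inputs=["latents"],
    callback_on_step_end=decoder_cb,
    callback_on_step_end_tensor_inputs=["latents"],
).images
```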
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from math import ceil from math import ceil
from typing import Callable, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -22,11 +22,7 @@ from transformers import CLIPTextModel, CLIPTokenizer ...@@ -22,11 +22,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from ...loaders import LoraLoaderMixin from ...loaders import LoraLoaderMixin
from ...schedulers import DDPMWuerstchenScheduler from ...schedulers import DDPMWuerstchenScheduler
from ...utils import ( from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
BaseOutput,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from .modeling_wuerstchen_prior import WuerstchenPrior from .modeling_wuerstchen_prior import WuerstchenPrior
...@@ -94,6 +90,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -94,6 +90,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
unet_name = "prior" unet_name = "prior"
text_encoder_name = "text_encoder" text_encoder_name = "text_encoder"
model_cpu_offload_seq = "text_encoder->prior" model_cpu_offload_seq = "text_encoder->prior"
_callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
def __init__( def __init__(
self, self,
...@@ -264,6 +261,18 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -264,6 +261,18 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
In Case you want to provide explicit timesteps, please use the 'timesteps' argument." In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
) )
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -282,8 +291,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -282,8 +291,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pt", output_type: Optional[str] = "pt",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_steps: int = 1, callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -331,12 +341,15 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -331,12 +341,15 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*): callback_on_step_end (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be A function that is called at the end of each denoising step during inference. The function is called
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_steps (`int`, *optional*, defaults to 1): callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
The frequency at which the `callback` function will be called. If not specified, the callback will be `callback_on_step_end_tensor_inputs`.
called at every step. callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -346,9 +359,32 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -346,9 +359,32 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
generated image embeddings. generated image embeddings.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
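The guard above rejects unknown tensor names before any denoising starts; an illustrative failure, assuming a `prior_pipe` instance of this class:

try:
    prior_pipe("a cat", callback_on_step_end_tensor_inputs=["image_latents"])
except ValueError as err:
    # "image_latents" is not in WuerstchenPriorPipeline._callback_tensor_inputs,
    # so the call fails fast with the message constructed above.
    print(err)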
# 0. Define commonly used variables # 0. Define commonly used variables
device = self._execution_device device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0 self._guidance_scale = guidance_scale
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
...@@ -363,7 +399,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -363,7 +399,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
else: else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list): if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str): if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] negative_prompt = [negative_prompt]
...@@ -376,7 +412,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -376,7 +412,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
prompt, prompt,
negative_prompt, negative_prompt,
num_inference_steps, num_inference_steps,
do_classifier_free_guidance, self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
) )
...@@ -386,7 +422,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -386,7 +422,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
prompt=prompt, prompt=prompt,
device=device, device=device,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
...@@ -419,21 +455,22 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -419,21 +455,22 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler) latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop # 6. Run denoising loop
self._num_timesteps = len(timesteps[:-1])
for i, t in enumerate(self.progress_bar(timesteps[:-1])): for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype) ratio = t.expand(latents.size(0)).to(dtype)
# 7. Denoise image embeddings # 7. Denoise image embeddings
predicted_image_embedding = self.prior( predicted_image_embedding = self.prior(
torch.cat([latents] * 2) if do_classifier_free_guidance else latents, torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio, r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio,
c=text_encoder_hidden_states, c=text_encoder_hidden_states,
) )
# 8. Check for classifier free guidance and apply it # 8. Check for classifier free guidance and apply it
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2)
predicted_image_embedding = torch.lerp( predicted_image_embedding = torch.lerp(
predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale
) )
# 9. Renoise latents to next timestep # 9. Renoise latents to next timestep
...@@ -444,6 +481,18 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -444,6 +481,18 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
generator=generator, generator=generator,
).prev_sample ).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
text_encoder_hidden_states = callback_outputs.pop(
"text_encoder_hidden_states", text_encoder_hidden_states
)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents) callback(step_idx, t, latents)
......
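The loop above builds `callback_kwargs` from `locals()` and writes any returned tensors back into the loop. A hedged end-to-end sketch against the prior pipeline, assuming the `warp-ai/wuerstchen-prior` checkpoint id:

import torch
from diffusers import WuerstchenPriorPipeline

prior_pipe = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
).to("cuda")

captured = {}

def grab_midpoint(pipe, i, t, callback_kwargs):
    # Snapshot the latents halfway through the run; anything left in the
    # returned dict replaces the loop's local tensor of the same name.
    if i == pipe.num_timesteps // 2:
        captured["latents"] = callback_kwargs["latents"].detach().clone()
    return callback_kwargs

out = prior_pipe(
    "an astronaut riding a horse",
    callback_on_step_end=grab_midpoint,
    callback_on_step_end_tensor_inputs=["latents"],
)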
...@@ -27,7 +27,12 @@ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( ...@@ -27,7 +27,12 @@ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
) )
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
...@@ -42,6 +47,7 @@ class AltDiffusionPipelineFastTests( ...@@ -42,6 +47,7 @@ class AltDiffusionPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
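The shared tester mixin drives these class attributes; a sketch of the property the new callback tests assert, under assumed naming:

def assert_tensor_inputs(pipe, i, t, callback_kwargs):
    # Every tensor listed on the pipeline class must be forwarded when requested.
    missing = [k for k in pipe._callback_tensor_inputs if k not in callback_kwargs]
    assert not missing, f"missing callback tensor inputs: {missing}"
    return callback_kwargs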