Unverified Commit 2b23ec82 authored by YiYi Xu, committed by GitHub

add callbacks to denoising step (#5427)



* draft1

* update

* style

* move to the end of loop

* update

* update callbak_on_step_end_inputs

* Revert "update"

This reverts commit 5f9b153183d0cde3b850f14024d2e37ae8c19576.

* Revert "update callbak_on_step_end_inputs"

This reverts commit 44889f4dabad95b7ebb330faa5f1955b5d008c88.

* update

* update test required_optional_params

* remove self.lora_scale

* img2img

* inpaint

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix

* apply feedback on img2img + inpaint: keep only important pipeline attributes

* depth

* pix2pix

* make _callback_tensor_inputs a class variable so that we can use it for testing

* add a basic test for callback

* add a read-only tensor input timesteps + fix tests

* add second test for callback cfg

* sdxl

* sdxl img2img

* sdxl inpaint

* kandinsky prior

* kandinsky decoder

* kandinsky img2img + combined

* kandinsky inpaint

* fix copies

* fix

* consistent default inputs

* fix copies

* wuerstchen_prior prior

* test_wuerstchen_decoder + fix test for prior

* wuerstchen_combined pipeline + skip tests

* skip test for kandinsky combined

* lcm

* remove timesteps etc

* add doc string

* copies

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* make style and improve tests

* up

* up

* fix more

* fix cfg test

* tests for callbacks

* fix for real

* update

* lcm img2img

* add doc

* add doc page to index

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 080081bd
...@@ -76,6 +76,8 @@ ...@@ -76,6 +76,8 @@
title: Kandinsky title: Kandinsky
- local: using-diffusers/controlnet - local: using-diffusers/controlnet
title: ControlNet title: ControlNet
- local: using-diffusers/callback
title: Callback
- local: using-diffusers/shap-e - local: using-diffusers/shap-e
title: Shap-E title: Shap-E
- local: using-diffusers/diffedit - local: using-diffusers/diffedit
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Using callback
[[open-in-colab]]
Most 🤗 Diffusers pipelines now accept a `callback_on_step_end` argument that lets you change the default behavior of the denoising loop with custom-defined functions. Here is an example of a callback function that disables classifier-free guidance (CFG) after 40% of the inference steps to save compute with minimal tradeoff in performance.
```python
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    # adjust the batch_size of prompt_embeds according to guidance_scale
    if step_index == int(pipe.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        prompt_embeds = prompt_embeds.chunk(2)[-1]

        # update guidance_scale and prompt_embeds
        pipe._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs
```
Your callback function receives the following arguments:
* `pipe` is the pipeline instance, which provides access to useful properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes; in this example, we disable CFG by setting `pipe._guidance_scale` to `0`.
* `step_index` and `timestep` tell you where you are in the denoising loop. In our example, we use `step_index` to decide when to turn off CFG.
* `callback_kwargs` is a dict containing the tensor variables you can modify during the denoising loop. It only includes the variables specified in the `callback_on_step_end_tensor_inputs` argument passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so check the pipeline class's `_callback_tensor_inputs` attribute for the list of variables you can modify (see the snippet below). Common variables include `latents` and `prompt_embeds`. In our example, we need to adjust the batch size of `prompt_embeds` after setting `guidance_scale` to `0` in order for it to work properly.
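If you are not sure which tensors a given pipeline exposes, you can inspect its `_callback_tensor_inputs` class attribute directly. A minimal sketch (the exact list varies from pipeline to pipeline):
```python
from diffusers import StableDiffusionPipeline

# Class attribute listing the tensor names you can request through
# `callback_on_step_end_tensor_inputs` and modify inside `callback_kwargs`.
print(StableDiffusionPipeline._callback_tensor_inputs)
# e.g. ["latents", "prompt_embeds", "negative_prompt_embeds"]
```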
You can pass the callback function to the pipeline through the `callback_on_step_end` argument, along with `callback_on_step_end_tensor_inputs`.
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.Generator(device="cuda").manual_seed(1)
out = pipe(prompt, generator=generator, callback_on_step_end=callback_dynamic_cfg, callback_on_step_end_tensor_inputs=["prompt_embeds"])

out.images[0].save("out_custom_cfg.png")
```
Your callback function is executed at the end of each denoising step, and it can modify the pipeline attributes and tensor variables used for the next denoising step. With this approach we added the "dynamic CFG" feature to the Stable Diffusion pipeline without having to modify its code at all.
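Callbacks are also handy for lightweight inspection. The sketch below (an illustrative example, not part of the library) logs the norm of the latents at every step without changing them; since `latents` is in the default `callback_on_step_end_tensor_inputs`, you can pass it as `callback_on_step_end=log_latents_norm` with no extra arguments.
```python
def log_latents_norm(pipe, step_index, timestep, callback_kwargs):
    # Read-only inspection: report how the latents evolve during denoising.
    latents = callback_kwargs["latents"]
    print(f"step {step_index:3d} | timestep {int(timestep):4d} | latents norm {latents.norm().item():.2f}")
    # Return the dict unchanged so the denoising loop is unaffected.
    return callback_kwargs
```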
<Tip>
Currently we only support `callback_on_step_end`. If you have a solid use case and require a callback function with a different execution point, please open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) so we can add it!
</Tip>
...@@ -110,6 +110,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -110,6 +110,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"] _exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__( def __init__(
self, self,
...@@ -500,17 +501,23 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -500,17 +501,23 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
...@@ -581,6 +588,33 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -581,6 +588,33 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
"""Disables the FreeU mechanism if enabled.""" """Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu() self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -599,11 +633,12 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -599,11 +633,12 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0, guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -647,12 +682,6 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -647,12 +682,6 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
...@@ -663,6 +692,15 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -663,6 +692,15 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -673,6 +711,25 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -673,6 +711,25 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
second element is a list of `bool`s indicating whether the corresponding generated image contains second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content. "not-safe-for-work" (nsfw) content.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using"
" `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using"
" `callback_on_step_end`",
)
# 0. Default height and width to unet # 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor
...@@ -680,9 +737,21 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -680,9 +737,21 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
) )
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
...@@ -692,29 +761,27 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -692,29 +761,27 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt, prompt,
device, device,
num_images_per_prompt, num_images_per_prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
negative_prompt, negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale, lora_scale=lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare timesteps # 4. Prepare timesteps
...@@ -739,10 +806,11 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -739,10 +806,11 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
# 7. Denoising loop # 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
...@@ -750,22 +818,32 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -750,22 +818,32 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0: if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
......
...@@ -148,6 +148,7 @@ class AltDiffusionImg2ImgPipeline( ...@@ -148,6 +148,7 @@ class AltDiffusionImg2ImgPipeline(
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"] _exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__( def __init__(
self, self,
...@@ -501,19 +502,31 @@ class AltDiffusionImg2ImgPipeline( ...@@ -501,19 +502,31 @@ class AltDiffusionImg2ImgPipeline(
return extra_step_kwargs return extra_step_kwargs
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self,
prompt,
strength,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -633,6 +646,29 @@ class AltDiffusionImg2ImgPipeline( ...@@ -633,6 +646,29 @@ class AltDiffusionImg2ImgPipeline(
"""Disables the FreeU mechanism if enabled.""" """Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu() self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -650,10 +686,11 @@ class AltDiffusionImg2ImgPipeline( ...@@ -650,10 +686,11 @@ class AltDiffusionImg2ImgPipeline(
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None, clip_skip: int = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -701,18 +738,21 @@ class AltDiffusionImg2ImgPipeline( ...@@ -701,18 +738,21 @@ class AltDiffusionImg2ImgPipeline(
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
Returns: Returns:
...@@ -722,8 +762,39 @@ class AltDiffusionImg2ImgPipeline( ...@@ -722,8 +762,39 @@ class AltDiffusionImg2ImgPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content. "not-safe-for-work" (nsfw) content.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use"
" `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use"
" `callback_on_step_end`",
)
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) self.check_inputs(
prompt,
strength,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -732,31 +803,28 @@ class AltDiffusionImg2ImgPipeline( ...@@ -732,31 +803,28 @@ class AltDiffusionImg2ImgPipeline(
batch_size = len(prompt) batch_size = len(prompt)
else: else:
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_encoder_lora_scale = ( text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
) )
prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt, prompt,
device, device,
num_images_per_prompt, num_images_per_prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
negative_prompt, negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale, lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Preprocess image # 4. Preprocess image
...@@ -769,7 +837,13 @@ class AltDiffusionImg2ImgPipeline( ...@@ -769,7 +837,13 @@ class AltDiffusionImg2ImgPipeline(
# 6. Prepare latent variables # 6. Prepare latent variables
latents = self.prepare_latents( latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator image,
latent_timestep,
batch_size,
num_images_per_prompt,
prompt_embeds.dtype,
device,
generator,
) )
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
...@@ -777,10 +851,11 @@ class AltDiffusionImg2ImgPipeline( ...@@ -777,10 +851,11 @@ class AltDiffusionImg2ImgPipeline(
# 8. Denoising loop # 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
...@@ -788,18 +863,28 @@ class AltDiffusionImg2ImgPipeline( ...@@ -788,18 +863,28 @@ class AltDiffusionImg2ImgPipeline(
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
......
...@@ -426,17 +426,22 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo ...@@ -426,17 +426,22 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
......
...@@ -12,16 +12,13 @@ ...@@ -12,16 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import torch import torch
from ...models import UNet2DConditionModel, VQModel from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler from ...schedulers import DDPMScheduler
from ...utils import ( from ...utils import deprecate, logging, replace_example_docstring
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
...@@ -81,6 +78,7 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -81,6 +78,7 @@ class KandinskyV22Pipeline(DiffusionPipeline):
""" """
model_cpu_offload_seq = "unet->movq" model_cpu_offload_seq = "unet->movq"
_callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"]
def __init__( def __init__(
self, self,
...@@ -109,6 +107,18 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -109,6 +107,18 @@ class KandinskyV22Pipeline(DiffusionPipeline):
latents = latents * scheduler.init_noise_sigma latents = latents * scheduler.init_noise_sigma
return latents return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -123,9 +133,10 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -123,9 +133,10 @@ class KandinskyV22Pipeline(DiffusionPipeline):
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
return_dict: bool = True, return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -160,23 +171,50 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -160,23 +171,50 @@ class KandinskyV22Pipeline(DiffusionPipeline):
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
Returns: Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] or `tuple`
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
device = self._execution_device device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0 self._guidance_scale = guidance_scale
if isinstance(image_embeds, list): if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0) image_embeds = torch.cat(image_embeds, dim=0)
...@@ -184,7 +222,7 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -184,7 +222,7 @@ class KandinskyV22Pipeline(DiffusionPipeline):
if isinstance(negative_image_embeds, list): if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0) negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
...@@ -193,7 +231,7 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -193,7 +231,7 @@ class KandinskyV22Pipeline(DiffusionPipeline):
) )
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps_tensor = self.scheduler.timesteps timesteps = self.scheduler.timesteps
num_channels_latents = self.unet.config.in_channels num_channels_latents = self.unet.config.in_channels
...@@ -209,9 +247,10 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -209,9 +247,10 @@ class KandinskyV22Pipeline(DiffusionPipeline):
self.scheduler, self.scheduler,
) )
for i, t in enumerate(self.progress_bar(timesteps_tensor)): self._num_timesteps = len(timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
added_cond_kwargs = {"image_embeds": image_embeds} added_cond_kwargs = {"image_embeds": image_embeds}
noise_pred = self.unet( noise_pred = self.unet(
...@@ -222,11 +261,11 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -222,11 +261,11 @@ class KandinskyV22Pipeline(DiffusionPipeline):
return_dict=False, return_dict=False,
)[0] )[0]
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not ( if not (
...@@ -243,17 +282,26 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -243,17 +282,26 @@ class KandinskyV22Pipeline(DiffusionPipeline):
generator=generator, generator=generator,
)[0] )[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
image_embeds = callback_outputs.pop("image_embeds", image_embeds)
negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds)
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents) callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
self.maybe_free_model_hooks()
if output_type not in ["pt", "np", "pil"]: if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
if not output_type == "latent":
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type in ["np", "pil"]: if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5 image = image * 0.5 + 0.5
image = image.clamp(0, 1) image = image.clamp(0, 1)
...@@ -261,6 +309,10 @@ class KandinskyV22Pipeline(DiffusionPipeline): ...@@ -261,6 +309,10 @@ class KandinskyV22Pipeline(DiffusionPipeline):
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
else:
image = latents
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image,) return (image,)
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -20,10 +20,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo ...@@ -20,10 +20,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from ...models import PriorTransformer, UNet2DConditionModel, VQModel from ...models import PriorTransformer, UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler, UnCLIPScheduler from ...schedulers import DDPMScheduler, UnCLIPScheduler
from ...utils import ( from ...utils import deprecate, logging, replace_example_docstring
logging,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from .pipeline_kandinsky2_2 import KandinskyV22Pipeline from .pipeline_kandinsky2_2 import KandinskyV22Pipeline
from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline
...@@ -220,6 +217,10 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline): ...@@ -220,6 +217,10 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1, callback_steps: int = 1,
return_dict: bool = True, return_dict: bool = True,
prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -264,14 +265,25 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline): ...@@ -264,14 +265,25 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
prior_callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step of the prior pipeline during inference. The
function is called with the following arguments: `prior_callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
list will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your prior pipeline class.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step of the decoder pipeline during inference.
The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline,
step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors
as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -288,6 +300,8 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline): ...@@ -288,6 +300,8 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
guidance_scale=prior_guidance_scale, guidance_scale=prior_guidance_scale,
output_type="pt", output_type="pt",
return_dict=False, return_dict=False,
callback_on_step_end=prior_callback_on_step_end,
callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
) )
image_embeds = prior_outputs[0] image_embeds = prior_outputs[0]
negative_image_embeds = prior_outputs[1] negative_image_embeds = prior_outputs[1]
...@@ -309,6 +323,8 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline): ...@@ -309,6 +323,8 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
callback=callback, callback=callback,
callback_steps=callback_steps, callback_steps=callback_steps,
return_dict=return_dict, return_dict=return_dict,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
) )
return outputs return outputs
...@@ -438,6 +454,10 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline): ...@@ -438,6 +454,10 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1, callback_steps: int = 1,
return_dict: bool = True, return_dict: bool = True,
prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -516,6 +536,8 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline): ...@@ -516,6 +536,8 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
guidance_scale=prior_guidance_scale, guidance_scale=prior_guidance_scale,
output_type="pt", output_type="pt",
return_dict=False, return_dict=False,
callback_on_step_end=prior_callback_on_step_end,
callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
) )
image_embeds = prior_outputs[0] image_embeds = prior_outputs[0]
negative_image_embeds = prior_outputs[1] negative_image_embeds = prior_outputs[1]
...@@ -547,6 +569,8 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline): ...@@ -547,6 +569,8 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
callback=callback, callback=callback,
callback_steps=callback_steps, callback_steps=callback_steps,
return_dict=return_dict, return_dict=return_dict,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
) )
return outputs return outputs
...@@ -663,9 +687,12 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline): ...@@ -663,9 +687,12 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
return_dict: bool = True, return_dict: bool = True,
prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -719,20 +746,48 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline): ...@@ -719,20 +746,48 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
prior_callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step of the prior pipeline during inference. The
function is called with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline,
step: int, timestep: int, callback_kwargs: Dict)`.
prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
list will be passed as the `callback_kwargs` argument. You will only be able to include variables listed
in the `._callback_tensor_inputs` attribute of your pipeline class.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
Returns: Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] or `tuple`
""" """
prior_kwargs = {}
if kwargs.get("prior_callback", None) is not None:
prior_kwargs["callback"] = kwargs.pop("prior_callback")
deprecate(
"prior_callback",
"1.0.0",
"Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
)
if kwargs.get("prior_callback_steps", None) is not None:
deprecate(
"prior_callback_steps",
"1.0.0",
"Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
)
prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps")
prior_outputs = self.prior_pipe( prior_outputs = self.prior_pipe(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
...@@ -743,6 +798,9 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline): ...@@ -743,6 +798,9 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
guidance_scale=prior_guidance_scale, guidance_scale=prior_guidance_scale,
output_type="pt", output_type="pt",
return_dict=False, return_dict=False,
callback_on_step_end=prior_callback_on_step_end,
callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
**prior_kwargs,
) )
image_embeds = prior_outputs[0] image_embeds = prior_outputs[0]
negative_image_embeds = prior_outputs[1] negative_image_embeds = prior_outputs[1]
...@@ -779,8 +837,9 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline): ...@@ -779,8 +837,9 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
generator=generator, generator=generator,
guidance_scale=guidance_scale, guidance_scale=guidance_scale,
output_type=output_type, output_type=output_type,
callback=callback,
callback_steps=callback_steps,
return_dict=return_dict, return_dict=return_dict,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
**kwargs,
) )
return outputs return outputs
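With these changes, the combined Kandinsky pipelines expose separate hooks for the prior and the decoder loops. Below is a minimal sketch of how the two callbacks could be wired up; the checkpoint id and the logging logic are illustrative assumptions, not part of this diff.

```python
import torch
from diffusers import AutoPipelineForText2Image

# Assumed Kandinsky 2.2 combined checkpoint; any KandinskyV22CombinedPipeline behaves the same way.
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
).to("cuda")


def log_prior_step(prior_pipe, step_index, timestep, callback_kwargs):
    # Runs inside the prior sub-pipeline; only tensors named in
    # prior_callback_on_step_end_tensor_inputs show up in callback_kwargs.
    print(f"prior step {step_index}: latents std {callback_kwargs['latents'].std().item():.4f}")
    return callback_kwargs


def log_decoder_step(decoder_pipe, step_index, timestep, callback_kwargs):
    # Runs inside the decoder sub-pipeline, which now exposes num_timesteps.
    print(f"decoder step {step_index + 1}/{decoder_pipe.num_timesteps}")
    return callback_kwargs


image = pipe(
    "a cat wearing a spacesuit, 4k photo",
    prior_callback_on_step_end=log_prior_step,
    prior_callback_on_step_end_tensor_inputs=["latents"],
    callback_on_step_end=log_decoder_step,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```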
...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -21,9 +21,7 @@ from PIL import Image ...@@ -21,9 +21,7 @@ from PIL import Image
from ...models import UNet2DConditionModel, VQModel from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler from ...schedulers import DDPMScheduler
from ...utils import ( from ...utils import deprecate, logging
logging,
)
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
...@@ -108,6 +106,7 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
""" """
model_cpu_offload_seq = "unet->movq" model_cpu_offload_seq = "unet->movq"
_callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"]
def __init__( def __init__(
self, self,
...@@ -176,6 +175,18 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline): ...@@ -176,6 +175,18 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
return latents return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
...@@ -190,9 +201,10 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline): ...@@ -190,9 +201,10 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
return_dict: bool = True, return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -233,23 +245,50 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline): ...@@ -233,23 +245,50 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
Returns: Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] or `tuple`
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
device = self._execution_device device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0 self._guidance_scale = guidance_scale
if isinstance(image_embeds, list): if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0) image_embeds = torch.cat(image_embeds, dim=0)
...@@ -257,7 +296,7 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline): ...@@ -257,7 +296,7 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
if isinstance(negative_image_embeds, list): if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0) negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
...@@ -284,9 +323,10 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline): ...@@ -284,9 +323,10 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
latents = self.prepare_latents( latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
) )
self._num_timesteps = len(timesteps)
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
added_cond_kwargs = {"image_embeds": image_embeds} added_cond_kwargs = {"image_embeds": image_embeds}
noise_pred = self.unet( noise_pred = self.unet(
...@@ -297,11 +337,11 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline): ...@@ -297,11 +337,11 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
return_dict=False, return_dict=False,
)[0] )[0]
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not ( if not (
...@@ -318,19 +358,28 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline): ...@@ -318,19 +358,28 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
generator=generator, generator=generator,
)[0] )[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
image_embeds = callback_outputs.pop("image_embeds", image_embeds)
negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds)
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents) callback(step_idx, t, latents)
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
)
if not output_type == "latent":
# post-processing # post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"] image = self.movq.decode(latents, force_not_quantize=True)["sample"]
# Offload all models
self.maybe_free_model_hooks()
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
if output_type in ["np", "pil"]: if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5 image = image * 0.5 + 0.5
image = image.clamp(0, 1) image = image.clamp(0, 1)
...@@ -338,6 +387,11 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline): ...@@ -338,6 +387,11 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
else:
image = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image,) return (image,)
......
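Since `_callback_tensor_inputs` for `KandinskyV22Img2ImgPipeline` whitelists `latents`, `image_embeds` and `negative_image_embeds`, a callback can request the conditioning embeddings by name. A rough sketch follows; the checkpoint ids and the sample image URL are assumptions for illustration only.

```python
import torch
from diffusers import KandinskyV22PriorPipeline, KandinskyV22Img2ImgPipeline
from diffusers.utils import load_image

# Assumed prior + decoder checkpoints for Kandinsky 2.2.
prior = KandinskyV22PriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
).to("cuda")
decoder = KandinskyV22Img2ImgPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
).to("cuda")

# Assumed sample image URL.
init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)

image_embeds, negative_image_embeds = prior("a red cat, 4k photo").to_tuple()


def inspect_embeds(pipe, step_index, timestep, callback_kwargs):
    # Only names listed in pipe._callback_tensor_inputs may be requested:
    # ["latents", "image_embeds", "negative_image_embeds"] for this pipeline.
    print(step_index, callback_kwargs["image_embeds"].shape)
    # Whatever is returned overwrites the corresponding variables in the denoising loop.
    return callback_kwargs


image = decoder(
    image=init_image,
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
    callback_on_step_end=inspect_embeds,
    callback_on_step_end_tensor_inputs=["latents", "image_embeds"],
).images[0]
```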
...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from copy import deepcopy from copy import deepcopy
from typing import Callable, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -25,9 +25,7 @@ from PIL import Image ...@@ -25,9 +25,7 @@ from PIL import Image
from ... import __version__ from ... import __version__
from ...models import UNet2DConditionModel, VQModel from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler from ...schedulers import DDPMScheduler
from ...utils import ( from ...utils import deprecate, logging
logging,
)
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
...@@ -251,6 +249,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
""" """
model_cpu_offload_seq = "unet->movq" model_cpu_offload_seq = "unet->movq"
_callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds", "masked_image", "mask_image"]
def __init__( def __init__(
self, self,
...@@ -280,6 +279,18 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -280,6 +279,18 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
latents = latents * scheduler.init_noise_sigma latents = latents * scheduler.init_noise_sigma
return latents return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
...@@ -295,9 +306,10 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -295,9 +306,10 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
return_dict: bool = True, return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -340,14 +352,17 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -340,14 +352,17 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -367,9 +382,32 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -367,9 +382,32 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
) )
self._warn_has_been_called = True self._warn_has_been_called = True
device = self._execution_device callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
self._guidance_scale = guidance_scale
do_classifier_free_guidance = guidance_scale > 1.0 device = self._execution_device
if isinstance(image_embeds, list): if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0) image_embeds = torch.cat(image_embeds, dim=0)
...@@ -377,7 +415,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -377,7 +415,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
if isinstance(negative_image_embeds, list): if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0) negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
...@@ -386,7 +424,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -386,7 +424,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
) )
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps_tensor = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# preprocess image and mask # preprocess image and mask
mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)
...@@ -407,7 +445,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -407,7 +445,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
mask_image = mask_image.repeat(2, 1, 1, 1) mask_image = mask_image.repeat(2, 1, 1, 1)
masked_image = masked_image.repeat(2, 1, 1, 1) masked_image = masked_image.repeat(2, 1, 1, 1)
...@@ -425,9 +463,11 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -425,9 +463,11 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
self.scheduler, self.scheduler,
) )
noise = torch.clone(latents) noise = torch.clone(latents)
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
self._num_timesteps = len(timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1)
added_cond_kwargs = {"image_embeds": image_embeds} added_cond_kwargs = {"image_embeds": image_embeds}
...@@ -439,11 +479,11 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -439,11 +479,11 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
return_dict=False, return_dict=False,
)[0] )[0]
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not ( if not (
...@@ -462,27 +502,40 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -462,27 +502,40 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
init_latents_proper = image[:1] init_latents_proper = image[:1]
init_mask = mask_image[:1] init_mask = mask_image[:1]
if i < len(timesteps_tensor) - 1: if i < len(timesteps) - 1:
noise_timestep = timesteps_tensor[i + 1] noise_timestep = timesteps[i + 1]
init_latents_proper = self.scheduler.add_noise( init_latents_proper = self.scheduler.add_noise(
init_latents_proper, noise, torch.tensor([noise_timestep]) init_latents_proper, noise, torch.tensor([noise_timestep])
) )
latents = init_mask * init_latents_proper + (1 - init_mask) * latents latents = init_mask * init_latents_proper + (1 - init_mask) * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
image_embeds = callback_outputs.pop("image_embeds", image_embeds)
negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds)
masked_image = callback_outputs.pop("masked_image", masked_image)
mask_image = callback_outputs.pop("mask_image", mask_image)
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents) callback(step_idx, t, latents)
# post-processing # post-processing
latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
# Offload all models if output_type not in ["pt", "np", "pil", "latent"]:
self.maybe_free_model_hooks() raise ValueError(
f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
)
if output_type not in ["pt", "np", "pil"]: if not output_type == "latent":
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type in ["np", "pil"]: if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5 image = image * 0.5 + 0.5
...@@ -491,6 +544,11 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline): ...@@ -491,6 +544,11 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
else:
image = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict: if not return_dict:
return (image,) return (image,)
......
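For the inpainting pipeline, the same hook can be used to record intermediate latents. Here is a sketch that assumes the `init_image`, `image_embeds` and `negative_image_embeds` from the previous example, plus the commonly used inpaint checkpoint id; none of this is part of the diff itself.

```python
import numpy as np
import torch
from diffusers import KandinskyV22InpaintPipeline

# Assumed inpaint checkpoint id.
inpaint_pipe = KandinskyV22InpaintPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
).to("cuda")

snapshots = []


def collect_latents(pipe, step_index, timestep, callback_kwargs):
    # Clone so later in-place updates in the loop don't overwrite the stored tensor.
    snapshots.append(callback_kwargs["latents"].clone())
    return callback_kwargs


# White (1) pixels mark the region to repaint.
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1

result = inpaint_pipe(
    image=init_image,                      # PIL image from the img2img sketch above
    mask_image=mask,
    image_embeds=image_embeds,             # from the prior pipeline, as above
    negative_image_embeds=negative_image_embeds,
    height=768,
    width=768,
    callback_on_step_end=collect_latents,
)
print(f"collected {len(snapshots)} intermediate latents")
```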
from typing import List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -106,6 +106,7 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->image_encoder->prior" model_cpu_offload_seq = "text_encoder->image_encoder->prior"
_exclude_from_cpu_offload = ["prior"] _exclude_from_cpu_offload = ["prior"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "text_encoder_hidden_states", "text_mask"]
def __init__( def __init__(
self, self,
...@@ -354,6 +355,18 @@ class KandinskyV22PriorPipeline(DiffusionPipeline): ...@@ -354,6 +355,18 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
return prompt_embeds, text_encoder_hidden_states, text_mask return prompt_embeds, text_encoder_hidden_states, text_mask
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -367,6 +380,8 @@ class KandinskyV22PriorPipeline(DiffusionPipeline): ...@@ -367,6 +380,8 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
guidance_scale: float = 4.0, guidance_scale: float = 4.0,
output_type: Optional[str] = "pt", # pt only output_type: Optional[str] = "pt", # pt only
return_dict: bool = True, return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -400,6 +415,15 @@ class KandinskyV22PriorPipeline(DiffusionPipeline): ...@@ -400,6 +415,15 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
(`torch.Tensor`). (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -407,6 +431,13 @@ class KandinskyV22PriorPipeline(DiffusionPipeline): ...@@ -407,6 +431,13 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
[`KandinskyPriorPipelineOutput`] or `tuple` [`KandinskyPriorPipelineOutput`] or `tuple`
""" """
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if isinstance(prompt, str): if isinstance(prompt, str):
prompt = [prompt] prompt = [prompt]
elif not isinstance(prompt, list): elif not isinstance(prompt, list):
...@@ -428,14 +459,15 @@ class KandinskyV22PriorPipeline(DiffusionPipeline): ...@@ -428,14 +459,15 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
batch_size = len(prompt) batch_size = len(prompt)
batch_size = batch_size * num_images_per_prompt batch_size = batch_size * num_images_per_prompt
do_classifier_free_guidance = guidance_scale > 1.0 self._guidance_scale = guidance_scale
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt
) )
# prior # prior
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
prior_timesteps_tensor = self.scheduler.timesteps timesteps = self.scheduler.timesteps
embedding_dim = self.prior.config.embedding_dim embedding_dim = self.prior.config.embedding_dim
...@@ -447,10 +479,10 @@ class KandinskyV22PriorPipeline(DiffusionPipeline): ...@@ -447,10 +479,10 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
latents, latents,
self.scheduler, self.scheduler,
) )
self._num_timesteps = len(timesteps)
for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
predicted_image_embedding = self.prior( predicted_image_embedding = self.prior(
latent_model_input, latent_model_input,
...@@ -460,16 +492,16 @@ class KandinskyV22PriorPipeline(DiffusionPipeline): ...@@ -460,16 +492,16 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
attention_mask=text_mask, attention_mask=text_mask,
).predicted_image_embedding ).predicted_image_embedding
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( predicted_image_embedding = predicted_image_embedding_uncond + self.guidance_scale * (
predicted_image_embedding_text - predicted_image_embedding_uncond predicted_image_embedding_text - predicted_image_embedding_uncond
) )
if i + 1 == prior_timesteps_tensor.shape[0]: if i + 1 == timesteps.shape[0]:
prev_timestep = None prev_timestep = None
else: else:
prev_timestep = prior_timesteps_tensor[i + 1] prev_timestep = timesteps[i + 1]
latents = self.scheduler.step( latents = self.scheduler.step(
predicted_image_embedding, predicted_image_embedding,
...@@ -479,6 +511,19 @@ class KandinskyV22PriorPipeline(DiffusionPipeline): ...@@ -479,6 +511,19 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
prev_timestep=prev_timestep, prev_timestep=prev_timestep,
).prev_sample ).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
text_encoder_hidden_states = callback_outputs.pop(
"text_encoder_hidden_states", text_encoder_hidden_states
)
text_mask = callback_outputs.pop("text_mask", text_mask)
latents = self.prior.post_process_latents(latents) latents = self.prior.post_process_latents(latents)
image_embeddings = latents image_embeddings = latents
......
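The new validation added to `KandinskyV22PriorPipeline.__call__` rejects tensor names outside `_callback_tensor_inputs` before the denoising loop starts, so a typo fails fast. A small sketch of that failure mode (the checkpoint id is an assumption):

```python
import torch
from diffusers import KandinskyV22PriorPipeline

prior = KandinskyV22PriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
).to("cuda")


def noop(pipe, step_index, timestep, callback_kwargs):
    return callback_kwargs


# "text_embeds" is not in KandinskyV22PriorPipeline._callback_tensor_inputs,
# so the check raises before any denoising happens.
try:
    prior(
        "a portrait of a fox",
        callback_on_step_end=noop,
        callback_on_step_end_tensor_inputs=["latents", "text_embeds"],
    )
except ValueError as err:
    # The message lists the allowed names: latents, prompt_embeds,
    # text_encoder_hidden_states, text_mask.
    print(err)
```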
...@@ -118,6 +118,7 @@ class LatentConsistencyModelImg2ImgPipeline(
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"] _exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
def __init__( def __init__(
self, self,
...@@ -535,18 +536,24 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -535,18 +536,24 @@ class LatentConsistencyModelImg2ImgPipeline(
strength: float, strength: float,
callback_steps: int, callback_steps: int,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -559,6 +566,22 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -559,6 +566,22 @@ class LatentConsistencyModelImg2ImgPipeline(
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def clip_skip(self):
return self._clip_skip
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -575,10 +598,11 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -575,10 +598,11 @@ class LatentConsistencyModelImg2ImgPipeline(
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -621,18 +645,21 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -621,18 +645,21 @@ class LatentConsistencyModelImg2ImgPipeline(
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -643,8 +670,27 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -643,8 +670,27 @@ class LatentConsistencyModelImg2ImgPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content. "not-safe-for-work" (nsfw) content.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, strength, callback_steps, prompt_embeds) self.check_inputs(prompt, strength, callback_steps, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -658,7 +704,9 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -658,7 +704,9 @@ class LatentConsistencyModelImg2ImgPipeline(
# do_classifier_free_guidance = guidance_scale > 1.0 # do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
# NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided
# distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the
...@@ -672,7 +720,7 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -672,7 +720,7 @@ class LatentConsistencyModelImg2ImgPipeline(
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=None, negative_prompt_embeds=None,
lora_scale=lora_scale, lora_scale=lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# 4. Encode image # 4. Encode image
...@@ -700,7 +748,7 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -700,7 +748,7 @@ class LatentConsistencyModelImg2ImgPipeline(
# NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper
# CFG formulation, so we need to subtract 1 from the input guidance_scale. # CFG formulation, so we need to subtract 1 from the input guidance_scale.
# LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG) # LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG)
w = torch.tensor(guidance_scale - 1).repeat(bs) w = torch.tensor(self.guidance_scale - 1).repeat(bs)
w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to( w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to(
device=device, dtype=latents.dtype device=device, dtype=latents.dtype
) )
...@@ -710,6 +758,7 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -710,6 +758,7 @@ class LatentConsistencyModelImg2ImgPipeline(
# 8. LCM Multistep Sampling Loop # 8. LCM Multistep Sampling Loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
latents = latents.to(prompt_embeds.dtype) latents = latents.to(prompt_embeds.dtype)
...@@ -720,12 +769,22 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -720,12 +769,22 @@ class LatentConsistencyModelImg2ImgPipeline(
t, t,
timestep_cond=w_embedding, timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False) latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
w_embedding = callback_outputs.pop("w_embedding", w_embedding)
denoised = callback_outputs.pop("denoised", denoised)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
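Because `LatentConsistencyModelImg2ImgPipeline` exposes `denoised` as a callback tensor input, a callback can decode a preview image at every LCM step. A hedged sketch follows; the checkpoint id and image URL are assumptions, and fp16 decode quality depends on the VAE.

```python
import torch
from diffusers import LatentConsistencyModelImg2ImgPipeline
from diffusers.utils import load_image

# Assumed LCM-distilled checkpoint.
pipe = LatentConsistencyModelImg2ImgPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")

# Assumed sample image URL.
init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/sketch-mountains-input.jpg"
)

previews = []


def preview_denoised(pipe, step_index, timestep, callback_kwargs):
    # "denoised" is the predicted clean latent the LCM scheduler returns every step,
    # so it can be decoded without finishing the loop.
    denoised = callback_kwargs["denoised"] / pipe.vae.config.scaling_factor
    image = pipe.vae.decode(denoised.to(pipe.vae.dtype), return_dict=False)[0]
    previews.append(pipe.image_processor.postprocess(image, output_type="pil")[0])
    return callback_kwargs


image = pipe(
    "a watercolor painting of a mountain lake",
    image=init_image,
    num_inference_steps=4,
    strength=0.5,
    callback_on_step_end=preview_denoised,
    callback_on_step_end_tensor_inputs=["denoised"],
).images[0]
print(f"{len(previews)} preview frames collected")
```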
...@@ -26,7 +26,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LCMScheduler from ...schedulers import LCMScheduler
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
...@@ -93,6 +100,7 @@ class LatentConsistencyModelPipeline(
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"] _exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
def __init__( def __init__(
self, self,
...@@ -466,18 +474,24 @@ class LatentConsistencyModelPipeline( ...@@ -466,18 +474,24 @@ class LatentConsistencyModelPipeline(
width: int, width: int,
callback_steps: int, callback_steps: int,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
callback_on_step_end_tensor_inputs=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -490,6 +504,22 @@ class LatentConsistencyModelPipeline( ...@@ -490,6 +504,22 @@ class LatentConsistencyModelPipeline(
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def clip_skip(self):
return self._clip_skip
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -506,10 +536,11 @@ class LatentConsistencyModelPipeline( ...@@ -506,10 +536,11 @@ class LatentConsistencyModelPipeline(
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -552,18 +583,21 @@ class LatentConsistencyModelPipeline( ...@@ -552,18 +583,21 @@ class LatentConsistencyModelPipeline(
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -574,12 +608,32 @@ class LatentConsistencyModelPipeline( ...@@ -574,12 +608,32 @@ class LatentConsistencyModelPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content. "not-safe-for-work" (nsfw) content.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
# 0. Default height and width to unet # 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps, prompt_embeds) self.check_inputs(prompt, height, width, callback_steps, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -593,7 +647,9 @@ class LatentConsistencyModelPipeline( ...@@ -593,7 +647,9 @@ class LatentConsistencyModelPipeline(
# do_classifier_free_guidance = guidance_scale > 1.0 # do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
# NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided
# distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the
...@@ -607,7 +663,7 @@ class LatentConsistencyModelPipeline( ...@@ -607,7 +663,7 @@ class LatentConsistencyModelPipeline(
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=None, negative_prompt_embeds=None,
lora_scale=lora_scale, lora_scale=lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# 4. Prepare timesteps # 4. Prepare timesteps
...@@ -632,7 +688,7 @@ class LatentConsistencyModelPipeline( ...@@ -632,7 +688,7 @@ class LatentConsistencyModelPipeline(
# NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper
# CFG formulation, so we need to subtract 1 from the input guidance_scale. # CFG formulation, so we need to subtract 1 from the input guidance_scale.
# LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG) # LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG)
w = torch.tensor(guidance_scale - 1).repeat(bs) w = torch.tensor(self.guidance_scale - 1).repeat(bs)
w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to( w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to(
device=device, dtype=latents.dtype device=device, dtype=latents.dtype
) )
...@@ -642,6 +698,7 @@ class LatentConsistencyModelPipeline( ...@@ -642,6 +698,7 @@ class LatentConsistencyModelPipeline(
# 8. LCM MultiStep Sampling Loop: # 8. LCM MultiStep Sampling Loop:
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
latents = latents.to(prompt_embeds.dtype) latents = latents.to(prompt_embeds.dtype)
...@@ -652,12 +709,22 @@ class LatentConsistencyModelPipeline( ...@@ -652,12 +709,22 @@ class LatentConsistencyModelPipeline(
t, t,
timestep_cond=w_embedding, timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False) latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
w_embedding = callback_outputs.pop("w_embedding", w_embedding)
denoised = callback_outputs.pop("denoised", denoised)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
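The LatentConsistencyModelPipeline hunk above threads `callback_on_step_end` through the LCM sampling loop and lets the callback read back (or overwrite) `latents`, `prompt_embeds`, `w_embedding`, and `denoised`. A minimal usage sketch; the checkpoint id is an assumption for illustration, as is the availability of `denoised` in this pipeline's `_callback_tensor_inputs`:

```python
import torch
from diffusers import LatentConsistencyModelPipeline

def log_denoised(pipe, step_index, timestep, callback_kwargs):
    # Only the tensors requested via `callback_on_step_end_tensor_inputs`
    # show up in `callback_kwargs`; here that is the LCM `denoised` sample.
    denoised = callback_kwargs["denoised"]
    print(f"step {step_index}: denoised mean = {denoised.float().mean().item():.4f}")
    # The returned dict is handed back to the loop, so tensors could also be
    # replaced here instead of just inspected.
    return callback_kwargs

# Assumed LCM checkpoint; any LatentConsistencyModelPipeline checkpoint should behave the same.
pipe = LatentConsistencyModelPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a photo of a red fox in the snow",
    num_inference_steps=4,
    callback_on_step_end=log_denoised,
    callback_on_step_end_tensor_inputs=["denoised"],
).images[0]
```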
...@@ -146,17 +146,22 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline): ...@@ -146,17 +146,22 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
......
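Every `check_inputs` touched by the commit gains the same two guards: `callback_steps` only has to be a positive integer when it is actually passed, and every name in `callback_on_step_end_tensor_inputs` must be registered in the pipeline's `_callback_tensor_inputs` class variable. A standalone sketch of the second guard, with an illustrative allow-list:

```python
# Illustrative allow-list; each pipeline defines its own `_callback_tensor_inputs`.
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

def validate_callback_tensor_inputs(requested, allowed=_callback_tensor_inputs):
    # Mirrors the guard added to check_inputs: every requested tensor name
    # must be one the denoising loop actually exposes to the callback.
    if requested is not None and not all(k in allowed for k in requested):
        unknown = [k for k in requested if k not in allowed]
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {allowed}, but found {unknown}"
        )

validate_callback_tensor_inputs(["latents", "prompt_embeds"])  # passes silently
# validate_callback_tensor_inputs(["depth_mask"])              # would raise ValueError
```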
...@@ -444,19 +444,30 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor ...@@ -444,19 +444,30 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self,
prompt,
strength,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
......
...@@ -105,6 +105,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -105,6 +105,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"] _exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__( def __init__(
self, self,
...@@ -489,17 +490,22 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -489,17 +490,22 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
...@@ -570,6 +576,33 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -570,6 +576,33 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
"""Disables the FreeU mechanism if enabled.""" """Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu() self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -588,11 +621,12 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -588,11 +621,12 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0, guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -636,12 +670,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -636,12 +670,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
...@@ -652,6 +680,15 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -652,6 +680,15 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed in the `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
...@@ -662,6 +699,23 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -662,6 +699,23 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
second element is a list of `bool`s indicating whether the corresponding generated image contains second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content. "not-safe-for-work" (nsfw) content.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
# 0. Default height and width to unet # 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor
...@@ -669,9 +723,21 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -669,9 +723,21 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
) )
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
...@@ -681,29 +747,27 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -681,29 +747,27 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt, prompt,
device, device,
num_images_per_prompt, num_images_per_prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
negative_prompt, negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale, lora_scale=lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare timesteps # 4. Prepare timesteps
...@@ -728,10 +792,11 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -728,10 +792,11 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
# 7. Denoising loop # 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
...@@ -739,22 +804,32 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo ...@@ -739,22 +804,32 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0: if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
......
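In `StableDiffusionPipeline` the hunk above also moves `guidance_scale`, `guidance_rescale`, `clip_skip`, and `cross_attention_kwargs` into private attributes exposed as read-only properties, so a callback can inspect the call state through the pipeline object it receives. A sketch that collects intermediate latents after every scheduler step; the checkpoint id is an assumption:

```python
import torch
from diffusers import StableDiffusionPipeline

intermediate_latents = []

def collect_latents(pipe, step_index, timestep, callback_kwargs):
    # Keep a CPU copy of the latents after each step, e.g. to visualise how
    # the sample evolves. Nothing is modified, so the loop continues unchanged.
    intermediate_latents.append(callback_kwargs["latents"].detach().cpu())
    return callback_kwargs

# Assumed checkpoint for illustration.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a watercolor painting of a lighthouse",
    num_inference_steps=30,
    callback_on_step_end=collect_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]

# The new read-only properties reflect the most recent call.
print(len(intermediate_latents), pipe.num_timesteps, pipe.guidance_scale)
```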
...@@ -96,6 +96,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -96,6 +96,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
""" """
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "depth_mask"]
def __init__( def __init__(
self, self,
...@@ -404,19 +405,30 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -404,19 +405,30 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self,
prompt,
strength,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -556,6 +568,29 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -556,6 +568,29 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
return depth_map return depth_map
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
...@@ -573,10 +608,11 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -573,10 +608,11 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -624,18 +660,21 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -624,18 +660,21 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed in the `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
```py ```py
...@@ -664,6 +703,23 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -664,6 +703,23 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images. otherwise a `tuple` is returned where the first element is a list with the generated images.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
# 1. Check inputs # 1. Check inputs
self.check_inputs( self.check_inputs(
prompt, prompt,
...@@ -672,8 +728,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -672,8 +728,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
) )
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
if image is None: if image is None:
raise ValueError("`image` input cannot be undefined.") raise ValueError("`image` input cannot be undefined.")
...@@ -686,30 +747,26 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -686,30 +747,26 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_encoder_lora_scale = ( text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
) )
prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt, prompt,
device, device,
num_images_per_prompt, num_images_per_prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
negative_prompt, negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale, lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare depth mask # 4. Prepare depth mask
...@@ -717,7 +774,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -717,7 +774,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
image, image,
depth_map, depth_map,
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
prompt_embeds.dtype, prompt_embeds.dtype,
device, device,
) )
...@@ -740,10 +797,11 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -740,10 +797,11 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
# 9. Denoising loop # 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)
...@@ -752,18 +810,29 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -752,18 +810,29 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
depth_mask = callback_outputs.pop("depth_mask", depth_mask)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
......
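The depth2img hunk follows the same pattern but additionally registers `depth_mask` in `_callback_tensor_inputs`, so callbacks can observe the depth conditioning alongside the latents. A small sketch of such a callback; how it would be passed to the pipeline call is shown in the trailing comment:

```python
def inspect_depth_mask(pipe, step_index, timestep, callback_kwargs):
    # `depth_mask` is available here only because StableDiffusionDepth2ImgPipeline
    # lists it in `_callback_tensor_inputs`; other pipelines reject the name.
    depth_mask = callback_kwargs["depth_mask"]
    if step_index == 0:
        print("depth mask shape:", tuple(depth_mask.shape))
    return callback_kwargs

# Passed to the depth2img call, e.g.:
# pipe(prompt, image=init_image,
#      callback_on_step_end=inspect_depth_mask,
#      callback_on_step_end_tensor_inputs=["depth_mask", "latents"])
```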
...@@ -483,17 +483,22 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline): ...@@ -483,17 +483,22 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
......
...@@ -142,6 +142,7 @@ class StableDiffusionImg2ImgPipeline( ...@@ -142,6 +142,7 @@ class StableDiffusionImg2ImgPipeline(
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"] _exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__( def __init__(
self, self,
...@@ -494,19 +495,30 @@ class StableDiffusionImg2ImgPipeline( ...@@ -494,19 +495,30 @@ class StableDiffusionImg2ImgPipeline(
return extra_step_kwargs return extra_step_kwargs
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self,
prompt,
strength,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -628,6 +640,29 @@ class StableDiffusionImg2ImgPipeline( ...@@ -628,6 +640,29 @@ class StableDiffusionImg2ImgPipeline(
"""Disables the FreeU mechanism if enabled.""" """Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu() self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
...@@ -645,10 +680,11 @@ class StableDiffusionImg2ImgPipeline( ...@@ -645,10 +680,11 @@ class StableDiffusionImg2ImgPipeline(
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None, clip_skip: int = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -696,18 +732,21 @@ class StableDiffusionImg2ImgPipeline( ...@@ -696,18 +732,21 @@ class StableDiffusionImg2ImgPipeline(
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple. plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*): clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings. the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed in the `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples: Examples:
Returns: Returns:
...@@ -717,8 +756,37 @@ class StableDiffusionImg2ImgPipeline( ...@@ -717,8 +756,37 @@ class StableDiffusionImg2ImgPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content. "not-safe-for-work" (nsfw) content.
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) self.check_inputs(
prompt,
strength,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -727,31 +795,28 @@ class StableDiffusionImg2ImgPipeline( ...@@ -727,31 +795,28 @@ class StableDiffusionImg2ImgPipeline(
batch_size = len(prompt) batch_size = len(prompt)
else: else:
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_encoder_lora_scale = ( text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
) )
prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt, prompt,
device, device,
num_images_per_prompt, num_images_per_prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
negative_prompt, negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale, lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip, clip_skip=self.clip_skip,
) )
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Preprocess image # 4. Preprocess image
...@@ -764,7 +829,13 @@ class StableDiffusionImg2ImgPipeline( ...@@ -764,7 +829,13 @@ class StableDiffusionImg2ImgPipeline(
# 6. Prepare latent variables # 6. Prepare latent variables
latents = self.prepare_latents( latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator image,
latent_timestep,
batch_size,
num_images_per_prompt,
prompt_embeds.dtype,
device,
generator,
) )
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
...@@ -772,10 +843,11 @@ class StableDiffusionImg2ImgPipeline( ...@@ -772,10 +843,11 @@ class StableDiffusionImg2ImgPipeline(
# 8. Denoising loop # 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
...@@ -783,18 +855,28 @@ class StableDiffusionImg2ImgPipeline( ...@@ -783,18 +855,28 @@ class StableDiffusionImg2ImgPipeline(
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
......
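As in the other pipelines, the img2img hunk keeps the legacy `callback`/`callback_steps` keywords working by popping them out of `**kwargs` and routing them through `deprecate(...)` with a 1.0.0 removal target, while the new per-step hook receives the pipeline object itself. A sketch contrasting the two styles; the commented calls assume an already constructed `pipe` and `init_image`:

```python
# Legacy style: still accepted for now, but flagged for removal in 1.0.0.
def legacy_callback(step, timestep, latents):
    print("legacy callback at step", step)

# New style: the pipeline is passed in, so call-time state exposed by the new
# properties (guidance_scale, do_classifier_free_guidance, ...) is readable here.
def step_end_callback(pipe, step_index, timestep, callback_kwargs):
    if step_index == 0:
        print("cfg enabled:", pipe.do_classifier_free_guidance,
              "guidance scale:", pipe.guidance_scale)
    return callback_kwargs

# image = pipe(prompt, image=init_image, callback=legacy_callback).images[0]               # emits a deprecation warning
# image = pipe(prompt, image=init_image, callback_on_step_end=step_end_callback).images[0]
```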
...@@ -205,6 +205,7 @@ class StableDiffusionInpaintPipeline( ...@@ -205,6 +205,7 @@ class StableDiffusionInpaintPipeline(
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"] _exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"]
def __init__( def __init__(
self, self,
...@@ -562,6 +563,7 @@ class StableDiffusionInpaintPipeline( ...@@ -562,6 +563,7 @@ class StableDiffusionInpaintPipeline(
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
): ):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
...@@ -569,14 +571,19 @@ class StableDiffusionInpaintPipeline( ...@@ -569,14 +571,19 @@ class StableDiffusionInpaintPipeline(
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -758,6 +765,29 @@ class StableDiffusionInpaintPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
def __call__(
self,
...@@ -779,10 +809,11 @@ class StableDiffusionInpaintPipeline(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
...@@ -846,18 +877,21 @@ class StableDiffusionInpaintPipeline(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
```py
...@@ -896,6 +930,23 @@ class StableDiffusionInpaintPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
...@@ -910,8 +961,13 @@ class StableDiffusionInpaintPipeline(
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
...@@ -921,10 +977,6 @@ class StableDiffusionInpaintPipeline(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
...@@ -934,17 +986,17 @@ class StableDiffusionInpaintPipeline(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. set timesteps
...@@ -1011,7 +1063,7 @@ class StableDiffusionInpaintPipeline(
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
)
# 8. Check that sizes of mask, masked image and latents match
...@@ -1037,10 +1089,11 @@ class StableDiffusionInpaintPipeline(
# 10. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
...@@ -1053,20 +1106,20 @@ class StableDiffusionInpaintPipeline(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if num_channels_unet == 4:
init_latents_proper = image_latents
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
else:
init_mask = mask
...@@ -1079,6 +1132,18 @@ class StableDiffusionInpaintPipeline(
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
mask = callback_outputs.pop("mask", mask)
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
...
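Taken together, the changes above mean the inpaint pipeline now exposes `mask` and `masked_image_latents` to step-end callbacks in addition to `latents` and the prompt embeddings. A rough usage sketch, assuming the standard inpainting checkpoint and with `init_image` and `mask_image` standing in for whatever PIL inputs you already have:

```python
import torch
from diffusers import StableDiffusionInpaintPipeline

def inspect_mask_callback(pipe, step_index, timestep, callback_kwargs):
    # "mask" and "masked_image_latents" are available here because they are
    # listed in the inpaint pipeline's `_callback_tensor_inputs`.
    mask = callback_kwargs["mask"]
    print(step_index, timestep, mask.shape)
    # Returning the tensors unchanged leaves the denoising loop untouched.
    return callback_kwargs

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a white cat sitting on a park bench",
    image=init_image,          # PIL image of the scene (placeholder)
    mask_image=mask_image,     # PIL mask of the region to repaint (placeholder)
    callback_on_step_end=inspect_mask_callback,
    callback_on_step_end_tensor_inputs=["latents", "mask", "masked_image_latents"],
).images[0]
```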
...@@ -478,19 +478,30 @@ class StableDiffusionInpaintPipelineLegacy(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
self,
prompt,
strength,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...
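The `check_inputs` additions shown in both pipelines guard the new argument: any name not declared in the pipeline's `_callback_tensor_inputs` is rejected before the denoising loop starts. A rough sketch of that behaviour, assuming a loaded pipeline `pipe` whose class only declares `latents`, `prompt_embeds`, and `negative_prompt_embeds`:

```python
# Hypothetical check: "noise_pred" is not part of `_callback_tensor_inputs`,
# so check_inputs raises before any denoising step runs.
try:
    pipe(
        prompt="an astronaut riding a horse",
        callback_on_step_end=lambda p, i, t, kwargs: kwargs,
        callback_on_step_end_tensor_inputs=["latents", "noise_pred"],
    )
except ValueError as e:
    print(e)  # `callback_on_step_end_tensor_inputs` has to be in [...], but found ['noise_pred']
```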