Unverified Commit e44fc75a authored by Dimitri Barbot's avatar Dimitri Barbot Committed by GitHub
Browse files

Update sdxl reference pipeline to latest sdxl pipeline (#9938)



* Update sdxl reference community pipeline

* Update README.md

Add example images.

* Style & quality

* Use example images from huggingface documentation-images repository

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent e47cc1fc
...@@ -2619,16 +2619,17 @@ for obj in range(bs): ...@@ -2619,16 +2619,17 @@ for obj in range(bs):
### Stable Diffusion XL Reference ### Stable Diffusion XL Reference
This pipeline uses the Reference. Refer to the [stable_diffusion_reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference). This pipeline uses the Reference. Refer to the [Stable Diffusion Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference) section for more information.
```py ```py
import torch import torch
from PIL import Image # from diffusers import DiffusionPipeline
from diffusers.utils import load_image from diffusers.utils import load_image
from diffusers import DiffusionPipeline
from diffusers.schedulers import UniPCMultistepScheduler from diffusers.schedulers import UniPCMultistepScheduler
input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") from .stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg")
# pipe = DiffusionPipeline.from_pretrained( # pipe = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-base-1.0", # "stabilityai/stable-diffusion-xl-base-1.0",
...@@ -2646,7 +2647,7 @@ pipe = StableDiffusionXLReferencePipeline.from_pretrained( ...@@ -2646,7 +2647,7 @@ pipe = StableDiffusionXLReferencePipeline.from_pretrained(
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
result_img = pipe(ref_image=input_image, result_img = pipe(ref_image=input_image,
prompt="1girl", prompt="a dog",
num_inference_steps=20, num_inference_steps=20,
reference_attn=True, reference_attn=True,
reference_adain=True).images[0] reference_adain=True).images[0]
...@@ -2654,14 +2655,14 @@ result_img = pipe(ref_image=input_image, ...@@ -2654,14 +2655,14 @@ result_img = pipe(ref_image=input_image,
Reference Image Reference Image
![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png) ![reference_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg)
Output Image Output Image
`prompt: 1 girl` `prompt: a dog`
`reference_attn=True, reference_adain=True, num_inference_steps=20` `reference_attn=False, reference_adain=True, num_inference_steps=20`
![Output_image](https://github.com/zideliu/diffusers/assets/34944964/743848da-a215-48f9-ae39-b5e2ae49fb13) ![Output_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_adain_dog.png)
Reference Image Reference Image
![reference_image](https://github.com/huggingface/diffusers/assets/34944964/449bdab6-e744-4fb2-9620-d4068d9a741b) ![reference_image](https://github.com/huggingface/diffusers/assets/34944964/449bdab6-e744-4fb2-9620-d4068d9a741b)
...@@ -4696,4 +4697,4 @@ with torch.no_grad(): ...@@ -4696,4 +4697,4 @@ with torch.no_grad():
``` ```
In the folder examples/pixart there is also a script that can be used to train new models. In the folder examples/pixart there is also a script that can be used to train new models.
Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training. Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
\ No newline at end of file
# Based on stable_diffusion_reference.py # Based on stable_diffusion_reference.py
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -7,28 +8,33 @@ import PIL.Image ...@@ -7,28 +8,33 @@ import PIL.Image
import torch import torch
from diffusers import StableDiffusionXLPipeline from diffusers import StableDiffusionXLPipeline
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.models.attention import BasicTransformerBlock from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unets.unet_2d_blocks import ( from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
CrossAttnDownBlock2D, from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
CrossAttnUpBlock2D, from diffusers.utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging, replace_example_docstring
DownBlock2D,
UpBlock2D,
)
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.utils import PIL_INTERPOLATION, logging
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm # type: ignore
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
>>> import torch >>> import torch
>>> from diffusers import UniPCMultistepScheduler >>> from diffusers.schedulers import UniPCMultistepScheduler
>>> from diffusers.utils import load_image >>> from diffusers.utils import load_image
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") >>> input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg")
>>> pipe = StableDiffusionXLReferencePipeline.from_pretrained( >>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/stable-diffusion-xl-base-1.0",
...@@ -38,7 +44,7 @@ EXAMPLE_DOC_STRING = """ ...@@ -38,7 +44,7 @@ EXAMPLE_DOC_STRING = """
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
>>> result_img = pipe(ref_image=input_image, >>> result_img = pipe(ref_image=input_image,
prompt="1girl", prompt="a dog",
num_inference_steps=20, num_inference_steps=20,
reference_attn=True, reference_attn=True,
reference_adain=True).images[0] reference_adain=True).images[0]
...@@ -56,8 +62,6 @@ def torch_dfs(model: torch.nn.Module): ...@@ -56,8 +62,6 @@ def torch_dfs(model: torch.nn.Module):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" """
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
...@@ -72,33 +76,102 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): ...@@ -72,33 +76,102 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg return noise_cfg
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def _default_height_width(self, height, width, image): def retrieve_timesteps(
# NOTE: It is possible that a list of images have different scheduler,
# dimensions for each image, so just checking the first image num_inference_steps: Optional[int] = None,
# is not _exactly_ correct, but it is simple. device: Optional[Union[str, torch.device]] = None,
while isinstance(image, list): timesteps: Optional[List[int]] = None,
image = image[0] sigmas: Optional[List[float]] = None,
**kwargs,
if height is None: ):
if isinstance(image, PIL.Image.Image): r"""
height = image.height Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
elif isinstance(image, torch.Tensor): custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
height = image.shape[2]
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
height = (height // 8) * 8 # round down to nearest multiple of 8
if width is None: class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
if isinstance(image, PIL.Image.Image): def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
width = image.width refimage = refimage.to(device=device)
elif isinstance(image, torch.Tensor): if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
width = image.shape[3] self.upcast_vae()
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if refimage.dtype != self.vae.dtype:
refimage = refimage.to(dtype=self.vae.dtype)
# encode the mask image into latents space so we can concatenate it to the latents
if isinstance(generator, list):
ref_image_latents = [
self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(batch_size)
]
ref_image_latents = torch.cat(ref_image_latents, dim=0)
else:
ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
# duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
if ref_image_latents.shape[0] < batch_size:
if not batch_size % ref_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
width = (width // 8) * 8 ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
return height, width # aligning device to prevent device errors when concating it with the latent model input
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
return ref_image_latents
def prepare_image( def prepare_ref_image(
self, self,
image, image,
width, width,
...@@ -151,41 +224,42 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -151,41 +224,42 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
return image return image
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): def check_ref_inputs(
refimage = refimage.to(device=device) self,
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: ref_image,
self.upcast_vae() reference_guidance_start,
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) reference_guidance_end,
if refimage.dtype != self.vae.dtype: style_fidelity,
refimage = refimage.to(dtype=self.vae.dtype) reference_attn,
# encode the mask image into latents space so we can concatenate it to the latents reference_adain,
if isinstance(generator, list): ):
ref_image_latents = [ ref_image_is_pil = isinstance(ref_image, PIL.Image.Image)
self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i]) ref_image_is_tensor = isinstance(ref_image, torch.Tensor)
for i in range(batch_size)
]
ref_image_latents = torch.cat(ref_image_latents, dim=0)
else:
ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
# duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method if not ref_image_is_pil and not ref_image_is_tensor:
if ref_image_latents.shape[0] < batch_size: raise TypeError(
if not batch_size % ref_image_latents.shape[0] == 0: f"ref image must be passed and be one of PIL image or torch tensor, but is {type(ref_image)}"
raise ValueError( )
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents if not reference_attn and not reference_adain:
raise ValueError("`reference_attn` or `reference_adain` must be True.")
# aligning device to prevent device errors when concating it with the latent model input if style_fidelity < 0.0:
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) raise ValueError(f"style fidelity: {style_fidelity} can't be smaller than 0.")
return ref_image_latents if style_fidelity > 1.0:
raise ValueError(f"style fidelity: {style_fidelity} can't be larger than 1.0.")
if reference_guidance_start >= reference_guidance_end:
raise ValueError(
f"reference guidance start: {reference_guidance_start} cannot be larger or equal to reference guidance end: {reference_guidance_end}."
)
if reference_guidance_start < 0.0:
raise ValueError(f"reference guidance start: {reference_guidance_start} can't be smaller than 0.")
if reference_guidance_end > 1.0:
raise ValueError(f"reference guidance end: {reference_guidance_end} can't be larger than 1.0.")
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
...@@ -194,6 +268,8 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -194,6 +268,8 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
num_inference_steps: int = 50, num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
denoising_end: Optional[float] = None, denoising_end: Optional[float] = None,
guidance_scale: float = 5.0, guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
...@@ -206,28 +282,220 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -206,28 +282,220 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0, guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None, original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0), crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None, target_size: Optional[Tuple[int, int]] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
attention_auto_machine_weight: float = 1.0, attention_auto_machine_weight: float = 1.0,
gn_auto_machine_weight: float = 1.0, gn_auto_machine_weight: float = 1.0,
reference_guidance_start: float = 0.0,
reference_guidance_end: float = 1.0,
style_fidelity: float = 0.5, style_fidelity: float = 0.5,
reference_attn: bool = True, reference_attn: bool = True,
reference_adain: bool = True, reference_adain: bool = True,
**kwargs,
): ):
assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True." r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
ref_image (`torch.Tensor`, `PIL.Image.Image`):
The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
the type is specified as `Torch.Tensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
also be accepted as an image.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a target image resolution. It should be as same
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
attention_auto_machine_weight (`float`):
Weight of using reference query for self attention's context.
If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
gn_auto_machine_weight (`float`):
Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
reference_guidance_start (`float`, *optional*, defaults to 0.0):
The percentage of total steps at which the reference ControlNet starts applying.
reference_guidance_end (`float`, *optional*, defaults to 1.0):
The percentage of total steps at which the reference ControlNet stops applying.
style_fidelity (`float`):
style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important,
elif style_fidelity=0.0, prompt more important, else balanced.
reference_attn (`bool`):
Whether to use reference query for self attention's context.
reference_adain (`bool`):
Whether to use reference adain.
Examples:
Returns:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
# 0. Default height and width to unet if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
# height, width = self._default_height_width(height, width, ref_image) callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width to unet
height = height or self.default_sample_size * self.vae_scale_factor height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor
original_size = original_size or (height, width) original_size = original_size or (height, width)
target_size = target_size or (height, width) target_size = target_size or (height, width)
...@@ -244,8 +512,27 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -244,8 +512,27 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
negative_prompt_embeds, negative_prompt_embeds,
pooled_prompt_embeds, pooled_prompt_embeds,
negative_pooled_prompt_embeds, negative_pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
)
self.check_ref_inputs(
ref_image,
reference_guidance_start,
reference_guidance_end,
style_fidelity,
reference_attn,
reference_adain,
) )
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
...@@ -256,15 +543,11 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -256,15 +543,11 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_encoder_lora_scale = ( lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
) )
( (
prompt_embeds, prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds,
...@@ -275,17 +558,19 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -275,17 +558,19 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
prompt_2=prompt_2, prompt_2=prompt_2,
device=device, device=device,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2, negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale, lora_scale=lora_scale,
clip_skip=self.clip_skip,
) )
# 4. Preprocess reference image # 4. Preprocess reference image
ref_image = self.prepare_image( ref_image = self.prepare_ref_image(
image=ref_image, image=ref_image,
width=width, width=width,
height=height, height=height,
...@@ -296,9 +581,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -296,9 +581,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
) )
# 5. Prepare timesteps # 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
timesteps = self.scheduler.timesteps )
# 6. Prepare latent variables # 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels num_channels_latents = self.unet.config.in_channels
...@@ -312,6 +597,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -312,6 +597,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
generator, generator,
latents, latents,
) )
# 7. Prepare reference latent variables # 7. Prepare reference latent variables
ref_image_latents = self.prepare_ref_latents( ref_image_latents = self.prepare_ref_latents(
ref_image, ref_image,
...@@ -319,13 +605,21 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -319,13 +605,21 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
prompt_embeds.dtype, prompt_embeds.dtype,
device, device,
generator, generator,
do_classifier_free_guidance, self.do_classifier_free_guidance,
) )
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 9. Modify self attebtion and group norm # 8.1 Create tensor stating which reference controlnets to keep
reference_keeps = []
for i in range(len(timesteps)):
reference_keep = 1.0 - float(
i / len(timesteps) < reference_guidance_start or (i + 1) / len(timesteps) > reference_guidance_end
)
reference_keeps.append(reference_keep)
# 8.2 Modify self attention and group norm
MODE = "write" MODE = "write"
uc_mask = ( uc_mask = (
torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt) torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
...@@ -333,6 +627,8 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -333,6 +627,8 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
.bool() .bool()
) )
do_classifier_free_guidance = self.do_classifier_free_guidance
def hacked_basic_transformer_inner_forward( def hacked_basic_transformer_inner_forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -604,7 +900,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -604,7 +900,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
return hidden_states return hidden_states
def hacked_UpBlock2D_forward( def hacked_UpBlock2D_forward(
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs
): ):
eps = 1e-6 eps = 1e-6
for i, resnet in enumerate(self.resnets): for i, resnet in enumerate(self.resnets):
...@@ -684,7 +980,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -684,7 +980,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
module.var_bank = [] module.var_bank = []
module.gn_weight *= 2 module.gn_weight *= 2
# 10. Prepare added time ids & embeddings # 9. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None: if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
...@@ -698,62 +994,101 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -698,62 +994,101 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
dtype=prompt_embeds.dtype, dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim, text_encoder_projection_dim=text_encoder_projection_dim,
) )
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device) prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device) add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 11. Denoising loop if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
# 10. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 10.1 Apply denoising_end # 10.1 Apply denoising_end
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: if (
self.denoising_end is not None
and isinstance(self.denoising_end, float)
and self.denoising_end > 0
and self.denoising_end < 1
):
discrete_timestep_cutoff = int( discrete_timestep_cutoff = int(
round( round(
self.scheduler.config.num_train_timesteps self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps) - (self.denoising_end * self.scheduler.config.num_train_timesteps)
) )
) )
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps] timesteps = timesteps[:num_inference_steps]
# 11. Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
added_cond_kwargs["image_embeds"] = image_embeds
# ref only part # ref only part
noise = randn_tensor( if reference_keeps[i] > 0:
ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype noise = randn_tensor(
) ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
ref_xt = self.scheduler.add_noise( )
ref_image_latents, ref_xt = self.scheduler.add_noise(
noise, ref_image_latents,
t.reshape( noise,
1, t.reshape(
), 1,
) ),
ref_xt = self.scheduler.scale_model_input(ref_xt, t) )
ref_xt = self.scheduler.scale_model_input(ref_xt, t)
MODE = "write"
MODE = "write"
self.unet( self.unet(
ref_xt, ref_xt,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
return_dict=False, return_dict=False,
) )
# predict the noise residual # predict the noise residual
MODE = "read" MODE = "read"
...@@ -761,22 +1096,44 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -761,22 +1096,44 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs, timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0: if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
...@@ -785,6 +1142,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -785,6 +1142,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents) callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent": if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16 # make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
...@@ -792,25 +1152,43 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline): ...@@ -792,25 +1152,43 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
if needs_upcasting: if needs_upcasting:
self.upcast_vae() self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
else:
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed # cast back to fp16 if needed
if needs_upcasting: if needs_upcasting:
self.vae.to(dtype=torch.float16) self.vae.to(dtype=torch.float16)
else: else:
image = latents image = latents
return StableDiffusionXLPipelineOutput(images=image)
# apply watermark if available if not output_type == "latent":
if self.watermark is not None: # apply watermark if available
image = self.watermark.apply_watermark(image) if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type) image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU # Offload all models
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.maybe_free_model_hooks()
self.final_offload_hook.offload()
if not return_dict: if not return_dict:
return (image,) return (image,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment