Unverified Commit f7dfcfd9 authored by M. Tolga Cangöz's avatar M. Tolga Cangöz Committed by GitHub
Browse files

[`IP-Adapter`] Fix IP-Adapter Support and Refactor Callback for...

[`IP-Adapter`] Fix IP-Adapter Support and Refactor Callback for `StableDiffusionPanoramaPipeline` (#7262)

* Add properties and `IPAdapterTesterMixin` tests for `StableDiffusionPanoramaPipeline`

* Update torch manual seed to use `torch.Generator(device=device)`

* Refactor 📞🔙 to support `callback_on_step_end`

* make fix-copies
parent 3c67864c
......@@ -13,7 +13,7 @@
import copy
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
......@@ -59,6 +59,66 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class StableDiffusionPanoramaPipeline(
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin
):
......@@ -97,6 +157,7 @@ class StableDiffusionPanoramaPipeline(
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
......@@ -461,10 +522,23 @@ class StableDiffusionPanoramaPipeline(
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def decode_latents_with_padding(self, latents, padding=8):
# Add padding to latents for circular inference
# padding is the number of latents to add on each side
# it would slightly increase the memory usage, but remove the boundary artifacts
def decode_latents_with_padding(self, latents: torch.Tensor, padding: int = 8) -> torch.Tensor:
"""
Decode the given latents with padding for circular inference.
Args:
latents (torch.Tensor): The input latents to decode.
padding (int, optional): The number of latents to add on each side for padding. Defaults to 8.
Returns:
torch.Tensor: The decoded image with padding removed.
Notes:
- The padding is added to remove boundary artifacts and improve the output quality.
- This would slightly increase the memory usage.
- The padding pixels are then removed from the decoded image.
"""
latents = 1 / self.vae.config.scaling_factor * latents
latents_left = latents[..., :padding]
latents_right = latents[..., -padding:]
......@@ -580,9 +654,62 @@ class StableDiffusionPanoramaPipeline(
latents = latents * self.scheduler.init_noise_sigma
return latents
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
# if panorama's height/width < window_size, num_blocks of height/width should return 1
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.FloatTensor:
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
w (`torch.Tensor`):
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
embedding_dim (`int`, *optional*, defaults to 512):
Dimension of the embeddings to generate.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
Data type of the generated embeddings.
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
"""
assert len(w.shape) == 1
w = w * 1000.0
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
def get_views(
self,
panorama_height: int,
panorama_width: int,
window_size: int = 64,
stride: int = 8,
circular_padding: bool = False,
) -> List[Tuple[int, int, int, int]]:
"""
Generates a list of views based on the given parameters.
Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113).
If panorama's height/width < window_size, num_blocks of height/width should return 1.
Args:
panorama_height (int): The height of the panorama.
panorama_width (int): The width of the panorama.
window_size (int, optional): The size of the window. Defaults to 64.
stride (int, optional): The stride value. Defaults to 8.
circular_padding (bool, optional): Whether to apply circular padding. Defaults to False.
Returns:
List[Tuple[int, int, int, int]]: A list of tuples representing the views. Each tuple contains
four integers representing the start and end coordinates of the window in the panorama.
"""
panorama_height /= 8
panorama_width /= 8
num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
......@@ -600,6 +727,34 @@ class StableDiffusionPanoramaPipeline(
views.append((h_start, h_end, w_start, w_end))
return views
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def clip_skip(self):
return self._clip_skip
@property
def do_classifier_free_guidance(self):
return False
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
......@@ -608,6 +763,7 @@ class StableDiffusionPanoramaPipeline(
height: Optional[int] = 512,
width: Optional[int] = 2048,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
view_batch_size: int = 1,
negative_prompt: Optional[Union[str, List[str]]] = None,
......@@ -621,11 +777,13 @@ class StableDiffusionPanoramaPipeline(
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
circular_padding: bool = False,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs: Any,
):
r"""
The call function to the pipeline for generation.
......@@ -641,6 +799,9 @@ class StableDiffusionPanoramaPipeline(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
The timesteps at which to generate the images. If not specified, then the default
timestep spacing strategy of the scheduler is used.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
......@@ -680,16 +841,12 @@ class StableDiffusionPanoramaPipeline(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescaling factor for the guidance embeddings. A value of 0.0 means no rescaling is applied.
circular_padding (`bool`, *optional*, defaults to `False`):
If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular
padding allows the model to seamlessly generate a transition from the rightmost part of the image to
......@@ -697,6 +854,15 @@ class StableDiffusionPanoramaPipeline(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
......@@ -706,6 +872,22 @@ class StableDiffusionPanoramaPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
......@@ -721,8 +903,15 @@ class StableDiffusionPanoramaPipeline(
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......@@ -768,8 +957,7 @@ class StableDiffusionPanoramaPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
......@@ -802,12 +990,23 @@ class StableDiffusionPanoramaPipeline(
else None
)
# 7.2 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
# 8. Denoising loop
# Each denoising step also includes refinement of the latents with respect to the
# views.
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
count.zero_()
value.zero_()
......@@ -863,6 +1062,7 @@ class StableDiffusionPanoramaPipeline(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds_input,
timestep_cond=timestep_cond,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
......@@ -872,6 +1072,12 @@ class StableDiffusionPanoramaPipeline(
noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
)
# compute the previous noisy sample x_t -> x_t-1
latents_denoised_batch = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs
......@@ -901,6 +1107,16 @@ class StableDiffusionPanoramaPipeline(
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = torch.where(count > 0, value / count, value)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
......@@ -908,7 +1124,7 @@ class StableDiffusionPanoramaPipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
if output_type != "latent":
if circular_padding:
image = self.decode_latents_with_padding(latents)
else:
......
......@@ -32,14 +32,16 @@ from diffusers import (
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
enable_full_determinism()
@skip_mps
class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
class StableDiffusionPanoramaPipelineFastTests(
IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionPanoramaPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
......@@ -96,7 +98,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
return components
def get_dummy_inputs(self, device, seed=0):
generator = torch.manual_seed(seed)
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "a photo of the dolomites",
"generator": generator,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment