Unverified Commit 36369904 authored by Aryan, committed by GitHub

[refactor] Fix FreeInit behaviour (#7410)

* fix freeinit impl

* fix progress bar

* fix progress bar and remove old code

* fix num_inference_steps==1 case for freeinit by running at least 1 step when fast sampling is enabled
parent 96135761
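
For context, a minimal usage sketch of the FreeInit code path this commit touches, assuming the diffusers AnimateDiff API at the time of this PR. The checkpoint IDs and parameter values below are illustrative, not taken from the diff:

import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter

# Illustrative checkpoints; any AnimateDiff-compatible base model and motion adapter work.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

# FreeInit restarts denoising for `num_iters` iterations; with `use_fast_sampling`
# each iteration runs only a growing fraction of `num_inference_steps` (coarse-to-fine),
# which is the behaviour fixed here for the progress bar and the 1-step edge case.
pipe.enable_free_init(num_iters=3, use_fast_sampling=True)

video = pipe(
    prompt="a panda surfing on a wave",
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(0),
).frames[0]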
@@ -792,7 +792,7 @@ class AnimateDiffPipeline(
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

             # 8. Denoising loop
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
@@ -944,7 +944,7 @@ class AnimateDiffVideoToVideoPipeline(
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

             # 8. Denoising loop
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
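
The two hunks above size the progress bar from the timesteps actually scheduled rather than the caller-supplied `num_inference_steps`. A rough sketch of why the two values diverge when FreeInit fast sampling is enabled (pure arithmetic, mirroring the formula in the mixin hunk below):

# With fast sampling, each FreeInit iteration re-calls scheduler.set_timesteps()
# with a smaller step count, so len(timesteps) != num_inference_steps except on
# the final iteration. Sizing the bar with num_inference_steps would over-count.
num_inference_steps = 20
num_free_init_iters = 3
for it in range(num_free_init_iters):
    steps = max(1, int(num_inference_steps / num_free_init_iters * (it + 1)))
    print(f"iteration {it}: progress bar total should be {steps}, not {num_inference_steps}")
# iteration 0: progress bar total should be 6, not 20
# iteration 1: progress bar total should be 13, not 20
# iteration 2: progress bar total should be 20, not 20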
@@ -146,39 +146,40 @@ class FreeInitMixin:
     ):
         if free_init_iteration == 0:
             self._free_init_initial_noise = latents.detach().clone()
-            return latents, self.scheduler.timesteps
-
-        latent_shape = latents.shape
-
-        free_init_filter_shape = (1, *latent_shape[1:])
-        free_init_freq_filter = self._get_free_init_freq_filter(
-            shape=free_init_filter_shape,
-            device=device,
-            filter_type=self._free_init_method,
-            order=self._free_init_order,
-            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
-            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
-        )
-
-        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
-        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
-
-        z_t = self.scheduler.add_noise(
-            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
-        ).to(dtype=torch.float32)
-
-        z_rand = randn_tensor(
-            shape=latent_shape,
-            generator=generator,
-            device=device,
-            dtype=torch.float32,
-        )
-        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
-        latents = latents.to(dtype)
+        else:
+            latent_shape = latents.shape
+
+            free_init_filter_shape = (1, *latent_shape[1:])
+            free_init_freq_filter = self._get_free_init_freq_filter(
+                shape=free_init_filter_shape,
+                device=device,
+                filter_type=self._free_init_method,
+                order=self._free_init_order,
+                spatial_stop_frequency=self._free_init_spatial_stop_frequency,
+                temporal_stop_frequency=self._free_init_temporal_stop_frequency,
+            )
+
+            current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
+            diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
+
+            z_t = self.scheduler.add_noise(
+                original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
+            ).to(dtype=torch.float32)
+
+            z_rand = randn_tensor(
+                shape=latent_shape,
+                generator=generator,
+                device=device,
+                dtype=torch.float32,
+            )
+            latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
+            latents = latents.to(dtype)

         # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
         if self._free_init_use_fast_sampling:
-            num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+            num_inference_steps = max(
+                1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+            )
             self.scheduler.set_timesteps(num_inference_steps, device=device)

         return latents, self.scheduler.timesteps
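
The `max(1, ...)` wrapper addresses the `num_inference_steps==1` case mentioned in the commit message. A small worked example of the old and new per-iteration step counts, reusing the formula from the hunk above:

num_inference_steps = 1
num_free_init_iters = 3

old = [int(num_inference_steps / num_free_init_iters * (i + 1)) for i in range(num_free_init_iters)]
new = [max(1, int(num_inference_steps / num_free_init_iters * (i + 1))) for i in range(num_free_init_iters)]

print(old)  # [0, 0, 1] -> the first two iterations would schedule zero denoising steps
print(new)  # [1, 1, 1] -> every iteration now runs at least one step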
@@ -13,14 +13,12 @@
 # limitations under the License.

 import inspect
-import math
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
 import PIL
 import torch
-import torch.fft as fft
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor
@@ -130,71 +128,6 @@ def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_sca
     return coef


-def _get_freeinit_freq_filter(
-    shape: Tuple[int, ...],
-    device: Union[str, torch.dtype],
-    filter_type: str,
-    order: float,
-    spatial_stop_frequency: float,
-    temporal_stop_frequency: float,
-) -> torch.Tensor:
-    r"""Returns the FreeInit filter based on filter type and other input conditions."""
-    time, height, width = shape[-3], shape[-2], shape[-1]
-    mask = torch.zeros(shape)
-
-    if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
-        return mask
-
-    if filter_type == "butterworth":
-
-        def retrieve_mask(x):
-            return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
-
-    elif filter_type == "gaussian":
-
-        def retrieve_mask(x):
-            return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
-
-    elif filter_type == "ideal":
-
-        def retrieve_mask(x):
-            return 1 if x <= spatial_stop_frequency * 2 else 0
-
-    else:
-        raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
-
-    for t in range(time):
-        for h in range(height):
-            for w in range(width):
-                d_square = (
-                    ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
-                    + (2 * h / height - 1) ** 2
-                    + (2 * w / width - 1) ** 2
-                )
-                mask[..., t, h, w] = retrieve_mask(d_square)
-
-    return mask.to(device)
-
-
-def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
-    r"""Noise reinitialization."""
-    # FFT
-    x_freq = fft.fftn(x, dim=(-3, -2, -1))
-    x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
-    noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
-    noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
-
-    # frequency mix
-    HPF = 1 - LPF
-    x_freq_low = x_freq * LPF
-    noise_freq_high = noise_freq * HPF
-    x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain
-
-    # IFFT
-    x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
-    x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
-
-    return x_mixed
-
-
 @dataclass
 class PIAPipelineOutput(BaseOutput):
     r"""
@@ -202,9 +135,9 @@ class PIAPipelineOutput(BaseOutput):
     Args:
         frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
            NumPy array of shape `(batch_size, num_frames, channels, height, width,
            Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
     """

     frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
@@ -788,7 +721,8 @@ class PIAPipeline(
                 The input image to be used for video generation.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
-            strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated video.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -979,8 +913,10 @@ class PIAPipeline(
                     latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
                 )

+            self._num_timesteps = len(timesteps)
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
+
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents