Unverified Commit c8bb1ff5 authored by Quentin Gallouédec's avatar Quentin Gallouédec Committed by GitHub
Browse files

Use HF Papers (#11567)



* Use HF Papers

* Apply style fixes

---------
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 799adf4a
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# DISCLAIMER: check https://arxiv.org/abs/2309.05019 # DISCLAIMER: check https://huggingface.co/papers/2309.05019
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math import math
...@@ -109,7 +109,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -109,7 +109,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`. Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`.
SA-Solver will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample SA-Solver will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample
from vanilla diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check from vanilla diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check
https://arxiv.org/abs/2309.05019 https://huggingface.co/papers/2309.05019
thresholding (`bool`, defaults to `False`): thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion. as Stable Diffusion.
...@@ -273,7 +273,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -273,7 +273,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace": if self.config.timestep_spacing == "linspace":
timesteps = ( timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
...@@ -348,7 +348,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -348,7 +348,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
pixels from saturation at each step. We find that dynamic thresholding results in significantly better pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights." photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487 https://huggingface.co/papers/2205.11487
""" """
dtype = sample.dtype dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape batch_size, channels, *remaining_dims = sample.shape
......
...@@ -61,7 +61,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -61,7 +61,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
The variance exploding stochastic differential equation (SDE) scheduler. The variance exploding stochastic differential equation (SDE) scheduler.
For more information, see the original paper: https://arxiv.org/abs/2011.13456 For more information, see the original paper: https://huggingface.co/papers/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
......
...@@ -96,7 +96,7 @@ def betas_for_alpha_bar( ...@@ -96,7 +96,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
""" """
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args: Args:
...@@ -334,7 +334,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): ...@@ -334,7 +334,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
pixels from saturation at each step. We find that dynamic thresholding results in significantly better pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights." photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487 https://huggingface.co/papers/2205.11487
""" """
dtype = sample.dtype dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape batch_size, channels, *remaining_dims = sample.shape
......
...@@ -191,7 +191,7 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): ...@@ -191,7 +191,7 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
else: else:
beta = 1 - alpha_prod_t / alpha_prod_t_prev beta = 1 - alpha_prod_t / alpha_prod_t_prev
# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://huggingface.co/papers/2006.11239)
# and sample from it to get previous sample # and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = beta_prod_t_prev / beta_prod_t * beta variance = beta_prod_t_prev / beta_prod_t * beta
...@@ -266,7 +266,7 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): ...@@ -266,7 +266,7 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
alpha = 1 - beta alpha = 1 - beta
# 2. compute predicted original sample from predicted noise also called # 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf # "predicted x_0" of formula (15) from https://huggingface.co/papers/2006.11239
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
...@@ -284,12 +284,12 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): ...@@ -284,12 +284,12 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
) )
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://huggingface.co/papers/2006.11239
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t
current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t # 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://huggingface.co/papers/2006.11239
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise # 6. Add noise
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info # DISCLAIMER: check https://huggingface.co/papers/2302.04867 and https://github.com/wl-zhao/UniPC for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math import math
...@@ -78,7 +78,7 @@ def betas_for_alpha_bar( ...@@ -78,7 +78,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas): def rescale_zero_terminal_snr(betas):
""" """
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args: Args:
...@@ -308,7 +308,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -308,7 +308,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, *optional*): device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
""" """
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace": if self.config.timestep_spacing == "linspace":
timesteps = ( timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
...@@ -429,7 +429,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -429,7 +429,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
pixels from saturation at each step. We find that dynamic thresholding results in significantly better pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights." photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487 https://huggingface.co/papers/2205.11487
""" """
dtype = sample.dtype dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape batch_size, channels, *remaining_dims = sample.shape
......
...@@ -149,9 +149,9 @@ def compute_dream_and_update_latents( ...@@ -149,9 +149,9 @@ def compute_dream_and_update_latents(
dream_detail_preservation: float = 1.0, dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
""" """
Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210. Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra https://huggingface.co/papers/2312.00210. DREAM helps align training with sampling to help training be more
forward step without gradients. efficient and accurate at the cost of an extra forward step without gradients.
Args: Args:
`unet`: The state unet to use to make a prediction. `unet`: The state unet to use to make a prediction.
...@@ -261,7 +261,7 @@ def compute_density_for_timestep_sampling( ...@@ -261,7 +261,7 @@ def compute_density_for_timestep_sampling(
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1. SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
""" """
if weighting_scheme == "logit_normal": if weighting_scheme == "logit_normal":
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator) u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
...@@ -280,7 +280,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): ...@@ -280,7 +280,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1. SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
""" """
if weighting_scheme == "sigma_sqrt": if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float() weighting = (sigmas**-2.0).float()
......
...@@ -91,7 +91,7 @@ def is_compiled_module(module) -> bool: ...@@ -91,7 +91,7 @@ def is_compiled_module(module) -> bool:
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). """Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
This version of the method comes from here: This version of the method comes from here:
https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706 https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment