Unverified Commit 5c520972 authored by G.O.D, committed by GitHub

Enable Flux pipeline compatibility with UniPC and DPM-Solver schedulers (#11908)



* Update pipeline_flux.py

Have the Flux pipeline work with UniPC/DPM schedulers

* clean code

* Update scheduling_dpmsolver_multistep.py

* Update scheduling_unipc_multistep.py

* Update pipeline_flux.py

* Update scheduling_deis_multistep.py

* Update scheduling_dpmsolver_singlestep.py

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
parent aa14f090
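
Note (illustrative, not part of the commit): with the changes below, Flux can be sampled with the UniPC / DPM-Solver family by swapping out the pipeline's default flow-match Euler scheduler. The sketch assumes the public FluxPipeline and UniPCMultistepScheduler APIs, the black-forest-labs/FLUX.1-dev checkpoint, and a CUDA device; only use_dynamic_shifting and time_shift_type come from this diff, the other flags are pre-existing config options.

# Illustrative sketch: running Flux with UniPC after this change.
import torch
from diffusers import FluxPipeline, UniPCMultistepScheduler

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Swap the default flow-match Euler scheduler for UniPC, keeping flow-matching
# sigma handling and enabling the dynamic shifting added by this commit.
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    use_dynamic_shifting=True,      # new config flag from this diff
    time_shift_type="exponential",  # new config flag from this diff
)
pipe.to("cuda")

image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=25,
    guidance_scale=3.5,
).images[0]
image.save("flux_unipc.png")
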
@@ -840,6 +840,8 @@ class FluxPipeline(
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+            sigmas = None
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
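
Note (illustrative, not part of the commit): the hunk above drops the pipeline's precomputed sigmas whenever the scheduler exposes use_flow_sigmas, so the resolution-dependent shift mu becomes the only schedule information handed to the scheduler. A rough standalone sketch of how that mu is derived from the packed latent sequence length follows; the constants mirror the Flux pipeline's calculate_shift defaults and should be read as assumptions.

# Rough sketch of the mu that the pipeline passes to scheduler.set_timesteps(..., mu=mu).
# Constants are assumptions mirroring the Flux calculate_shift defaults.
def calculate_shift_sketch(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    # Linearly interpolate the shift between the base and max sequence lengths.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# A 1024x1024 Flux image packs to a 4096-token latent sequence, so mu hits max_shift here.
mu = calculate_shift_sketch(64 * 64)  # -> 1.15
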
@@ -153,6 +153,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         flow_shift: Optional[float] = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -232,7 +234,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -242,6 +246,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
         if self.config.timestep_spacing == "linspace":
             timesteps = (
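
Note (illustrative, not part of the commit): the reason `self.config.flow_shift = np.exp(mu)` is enough to emulate dynamic shifting is algebraic. Assuming the static flow shift applied when use_flow_sigmas is enabled is s -> shift * s / (1 + (shift - 1) * s), plugging in shift = exp(mu) gives exp(mu) / (exp(mu) + (1 / s - 1)), which matches the exponential time shift used by the flow-match Euler scheduler with its sigma exponent left at 1.0. A quick numerical check under those assumptions:

# Numerical check that flow_shift = exp(mu) reproduces the exponential dynamic shift.
import numpy as np

mu = 1.15
sigmas = np.linspace(1.0, 1.0 / 28, 28)

static_shift = np.exp(mu) * sigmas / (1 + (np.exp(mu) - 1) * sigmas)
dynamic_shift = np.exp(mu) / (np.exp(mu) + (1 / sigmas - 1))

# The two schedules coincide, which is why rewriting config.flow_shift lets a
# statically-shifted scheduler follow a resolution-dependent mu.
assert np.allclose(static_shift, dynamic_shift)
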
@@ -230,6 +230,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self,
         num_inference_steps: int = None,
         device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
     ):
         """
@@ -345,6 +348,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                 must be `None`, and `timestep_spacing` attribute will be ignored.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if num_inference_steps is None and timesteps is None:
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
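
Note (illustrative, not part of the commit): at the scheduler level the new argument can be exercised directly. The sketch below assumes default betas, an arbitrary mu and step count, and the pre-existing use_flow_sigmas / flow_prediction options; only use_dynamic_shifting and time_shift_type are new in this diff.

# Scheduler-level sketch of the new `mu` argument; mu and the step count are arbitrary.
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    use_flow_sigmas=True,               # existing flag: build flow-matching sigmas
    prediction_type="flow_prediction",  # existing flag: flow-matching parameterization
    use_dynamic_shifting=True,          # new flag from this diff
    time_shift_type="exponential",      # new flag from this diff
)

# Passing mu rewrites config.flow_shift to exp(mu) before the sigma schedule is built.
scheduler.set_timesteps(num_inference_steps=28, mu=1.15)
print(scheduler.timesteps[:3], scheduler.sigmas[:3])
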
@@ -169,6 +169,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -301,6 +303,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self,
         num_inference_steps: int = None,
         device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
     ):
         """
@@ -316,6 +319,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
                 passed, `num_inference_steps` must be `None`.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if num_inference_steps is None and timesteps is None:
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
@@ -212,6 +212,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         steps_offset: int = 0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         rescale_betas_zero_snr: bool = False,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -298,7 +300,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -309,6 +313,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if self.config.timestep_spacing == "linspace":
             timesteps = (
                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)