"src/vscode:/vscode.git/clone" did not exist on "99540747b5d1ac977be3e55eea03ae63b2d63c80"
Unverified commit 5c520972, authored by G.O.D and committed by GitHub

Enable Flux pipeline compatibility with UniPC and DPM-Solver (#11908)



* Update pipeline_flux.py

Have the Flux pipeline work with the UniPC/DPM-Solver schedulers.

* clean code

* Update scheduling_dpmsolver_multistep.py

* Update scheduling_unipc_multistep.py

* Update pipeline_flux.py

* Update scheduling_deis_multistep.py

* Update scheduling_dpmsolver_singlestep.py

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
parent aa14f090
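
With this change the Flux pipeline can run on UniPC or a DPM-Solver scheduler by swapping the scheduler and enabling flow sigmas plus dynamic shifting. Below is a minimal usage sketch, not part of this commit; the checkpoint id and the exact config overrides are illustrative assumptions.

```python
# Minimal usage sketch (illustrative; not part of this commit).
import torch
from diffusers import FluxPipeline, UniPCMultistepScheduler

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    prediction_type="flow_prediction",
    use_flow_sigmas=True,           # the pipeline then passes sigmas=None (see the FluxPipeline hunk below)
    use_dynamic_shifting=True,      # new flag from this PR: honor the mu computed by the pipeline
    time_shift_type="exponential",  # new flag from this PR: mu is folded in as flow_shift = exp(mu)
)
pipe.to("cuda")

image = pipe("a tiny astronaut hatching from an egg on the moon", num_inference_steps=25).images[0]
image.save("flux_unipc.png")
```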
pipeline_flux.py

@@ -840,6 +840,8 @@ class FluxPipeline(
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+            sigmas = None
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
...
scheduling_deis_multistep.py

@@ -153,6 +153,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         flow_shift: Optional[float] = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -232,7 +234,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -242,6 +246,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
         if self.config.timestep_spacing == "linspace":
             timesteps = (
...
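
The `self.config.flow_shift = np.exp(mu)` line above folds Flux's exponential dynamic shift into the static flow shift the schedulers already apply when `use_flow_sigmas` is enabled. A short sketch of the equivalence, assuming the schedulers use the standard shift `s * sigma / (1 + (s - 1) * sigma)`:

```python
# Sketch: with s = exp(mu), the static flow shift matches the exponential
# dynamic shift, so storing flow_shift = np.exp(mu) reproduces the same sigmas.
import numpy as np

mu = 1.15
sigmas = np.linspace(1.0, 0.02, 10)

static = np.exp(mu) * sigmas / (1 + (np.exp(mu) - 1) * sigmas)  # s * sigma / (1 + (s - 1) * sigma)
dynamic = np.exp(mu) / (np.exp(mu) + (1 / sigmas - 1))          # exp(mu) / (exp(mu) + (1 / sigma - 1))

assert np.allclose(static, dynamic)
```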
scheduling_dpmsolver_multistep.py

@@ -230,6 +230,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self,
         num_inference_steps: int = None,
         device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
     ):
         """
@@ -345,6 +348,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                 must be `None`, and `timestep_spacing` attribute will be ignored.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if num_inference_steps is None and timesteps is None:
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
...
scheduling_dpmsolver_singlestep.py

@@ -169,6 +169,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -301,6 +303,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self,
         num_inference_steps: int = None,
         device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
     ):
         """
@@ -316,6 +319,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
                 passed, `num_inference_steps` must be `None`.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if num_inference_steps is None and timesteps is None:
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
...
scheduling_unipc_multistep.py

@@ -212,6 +212,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         steps_offset: int = 0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         rescale_betas_zero_snr: bool = False,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -298,7 +300,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -309,6 +313,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+            self.config.flow_shift = np.exp(mu)
         if self.config.timestep_spacing == "linspace":
             timesteps = (
                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
...
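
At the scheduler level, `mu` is only honored when `use_dynamic_shifting=True` and `time_shift_type="exponential"`; otherwise the assert in `set_timesteps` fails. A direct-call sketch, assuming the remaining config defaults are acceptable for a quick check:

```python
# Scheduler-level sketch (illustrative; assumes default betas and timesteps).
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    use_dynamic_shifting=True,      # required by the assert when mu is passed
    time_shift_type="exponential",  # required by the assert when mu is passed
)
scheduler.set_timesteps(num_inference_steps=28, mu=0.8)  # stores flow_shift = exp(0.8)
print(scheduler.sigmas[:5])
```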