"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "0659196b561d964b36a7369f6b73b74fa429cb2e"
Unverified commit f0c6d978, authored by Vladimir Mandic, committed by GitHub
Browse files

flux: make scheduler config params optional (#10384)



* don't assume the scheduler has the optional config params

* make style, make fix-copies

* calculate_shift

* fix-copies, usage in pipelines

---------
Co-authored-by: hlky <hlky@hlky.ac>
parent d006f076
...@@ -875,10 +875,10 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -875,10 +875,10 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -820,10 +820,10 @@ class RFInversionFluxPipeline( ...@@ -820,10 +820,10 @@ class RFInversionFluxPipeline(
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
...@@ -990,10 +990,10 @@ class RFInversionFluxPipeline( ...@@ -990,10 +990,10 @@ class RFInversionFluxPipeline(
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inversion_steps = retrieve_timesteps( timesteps, num_inversion_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -64,6 +64,7 @@ EXAMPLE_DOC_STRING = """ ...@@ -64,6 +64,7 @@ EXAMPLE_DOC_STRING = """
""" """
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift( def calculate_shift(
image_seq_len, image_seq_len,
base_seq_len: int = 256, base_seq_len: int = 256,
...@@ -755,10 +756,10 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi ...@@ -755,10 +756,10 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
image_seq_len = latents.shape[1] image_seq_len = latents.shape[1]
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -822,10 +822,10 @@ class FluxPipeline( ...@@ -822,10 +822,10 @@ class FluxPipeline(
image_seq_len = latents.shape[1] image_seq_len = latents.shape[1]
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -82,6 +82,7 @@ EXAMPLE_DOC_STRING = """ ...@@ -82,6 +82,7 @@ EXAMPLE_DOC_STRING = """
""" """
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift( def calculate_shift(
image_seq_len, image_seq_len,
base_seq_len: int = 256, base_seq_len: int = 256,
...@@ -798,10 +799,10 @@ class FluxControlPipeline( ...@@ -798,10 +799,10 @@ class FluxControlPipeline(
image_seq_len = latents.shape[1] image_seq_len = latents.shape[1]
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -807,10 +807,10 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin ...@@ -807,10 +807,10 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -984,10 +984,10 @@ class FluxControlInpaintPipeline( ...@@ -984,10 +984,10 @@ class FluxControlInpaintPipeline(
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -874,10 +874,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -874,10 +874,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
image_seq_len = latents.shape[1] image_seq_len = latents.shape[1]
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -862,10 +862,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -862,10 +862,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -1016,10 +1016,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -1016,10 +1016,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
) )
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -881,10 +881,10 @@ class FluxFillPipeline( ...@@ -881,10 +881,10 @@ class FluxFillPipeline(
image_seq_len = latents.shape[1] image_seq_len = latents.shape[1]
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -744,10 +744,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile ...@@ -744,10 +744,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -876,10 +876,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -876,10 +876,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -677,10 +677,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi ...@@ -677,10 +677,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift( mu = calculate_shift(
video_sequence_length, video_sequence_length,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -747,10 +747,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo ...@@ -747,10 +747,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift( mu = calculate_shift(
video_sequence_length, video_sequence_length,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -62,19 +62,6 @@ EXAMPLE_DOC_STRING = """ ...@@ -62,19 +62,6 @@ EXAMPLE_DOC_STRING = """
""" """
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
) -> float:
    """Linearly map a sequence length to a shift value ``mu``.

    The mapping is the straight line through the two points
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``. Lengths
    outside that interval are extrapolated along the same line, not clamped.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length at which the result equals ``base_shift``.
        max_seq_len: Sequence length at which the result equals ``max_shift``.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated (or extrapolated) shift ``mu``.

    Raises:
        ZeroDivisionError: If ``max_seq_len == base_seq_len``.
    """
    # Slope of the line joining the two anchor points.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    # Intercept chosen so the line passes through (base_seq_len, base_shift).
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 # from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
if linear_steps is None: if linear_steps is None:
......
...@@ -1013,10 +1013,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -1013,10 +1013,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
) )
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
scheduler_kwargs["mu"] = mu scheduler_kwargs["mu"] = mu
elif mu is not None: elif mu is not None:
......
...@@ -943,10 +943,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -943,10 +943,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
) )
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
scheduler_kwargs["mu"] = mu scheduler_kwargs["mu"] = mu
elif mu is not None: elif mu is not None:
......
...@@ -1053,10 +1053,10 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -1053,10 +1053,10 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
) )
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
self.scheduler.config.base_image_seq_len, self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.max_image_seq_len, self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.base_shift, self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.max_shift, self.scheduler.config.get("max_shift", 1.16),
) )
scheduler_kwargs["mu"] = mu scheduler_kwargs["mu"] = mu
elif mu is not None: elif mu is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment