Unverified commit bdd16116, authored by Patrick von Platen, committed by GitHub

[Schedulers] Fix callback steps (#5261)

* fix all

* make fix copies

* make fix copies
parent c8b0f0eb
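
Every hunk below applies the same one-line fix: before the user callback is invoked, the raw denoising-loop index `i` is divided by the scheduler's `order` (falling back to 1 via `getattr` when a scheduler does not expose the attribute). Higher-order schedulers such as Heun expand the timestep list, so the loop runs roughly `num_inference_steps * order` iterations and an unscaled `i` would overshoot the step count the caller asked for. The snippet below is a minimal, self-contained sketch of that effect; `FakeSecondOrderScheduler` and `denoise` are hypothetical stand-ins, not diffusers code.

```python
class FakeSecondOrderScheduler:
    order = 2  # e.g. Heun-style schedulers take two model evaluations per step

    def set_timesteps(self, num_inference_steps):
        # second-order schedulers expand the timestep list, so the loop
        # iterates num_inference_steps * order times
        self.timesteps = list(range(num_inference_steps * self.order))


def denoise(num_inference_steps=4, callback=None, callback_steps=1):
    scheduler = FakeSecondOrderScheduler()
    scheduler.set_timesteps(num_inference_steps)
    latents = None  # stands in for the denoised sample
    for i, t in enumerate(scheduler.timesteps):
        if callback is not None and i % callback_steps == 0:
            # without the division the callback would see indices 0..7 for a
            # 4-step run; with it, the index stays within 0..3
            step_idx = i // getattr(scheduler, "order", 1)
            callback(step_idx, t, latents)


denoise(callback=lambda step_idx, t, latents: print(f"callback step {step_idx}"))
```

With the division in place the callback receives step indices in the range the user configured (each index repeated for a second-order scheduler) rather than the internal iteration count.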
@@ -932,7 +932,8 @@ class StableDiffusionRepaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
     # call the callback, if provided
     progress_bar.update()
     if callback is not None and i % callback_steps == 0:
-        callback(i, t, latents)
+        step_idx = i // getattr(self.scheduler, "order", 1)
+        callback(step_idx, t, latents)
     t_last = t
...
@@ -771,7 +771,8 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 if not output_type == "latent":
     # make sure the VAE is in float32 mode, as it overflows in float16
...
@@ -389,7 +389,8 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
     # call the callback, if provided
     if callback is not None and i % callback_steps == 0:
-        callback(i, t, latents)
+        step_idx = i // getattr(self.scheduler, "order", 1)
+        callback(step_idx, t, latents)
 latents = 1 / 0.18215 * latents
 image = self.vae.decode(latents).sample
...
@@ -432,7 +432,8 @@ class RDMPipeline(DiffusionPipeline):
     # call the callback, if provided
     if callback is not None and i % callback_steps == 0:
-        callback(i, t, latents)
+        step_idx = i // getattr(self.scheduler, "order", 1)
+        callback(step_idx, t, latents)
 if not output_type == "latent":
     image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
 else:
...
@@ -736,7 +736,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 if not output_type == "latent":
     image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
...
@@ -765,7 +765,8 @@ class AltDiffusionImg2ImgPipeline(
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 if not output_type == "latent":
     image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
...
@@ -542,7 +542,8 @@ class AudioLDMPipeline(DiffusionPipeline):
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 # 8. Post-processing
 mel_spectrogram = self.decode_latents(latents)
...
@@ -945,7 +945,8 @@ class AudioLDM2Pipeline(DiffusionPipeline):
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 self.maybe_free_model_hooks()
...
@@ -1005,7 +1005,8 @@ class StableDiffusionControlNetPipeline(
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 # If we do sequential model offloading, let's offload unet and controlnet
 # manually for max memory savings
...
@@ -1087,7 +1087,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 # If we do sequential model offloading, let's offload unet and controlnet
 # manually for max memory savings
...
@@ -1351,7 +1351,8 @@ class StableDiffusionControlNetInpaintPipeline(
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 # If we do sequential model offloading, let's offload unet and controlnet
 # manually for max memory savings
...
@@ -1507,7 +1507,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 # make sure the VAE is in float32 mode, as it overflows in float16
 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
...
@@ -1158,7 +1158,8 @@ class StableDiffusionXLControlNetPipeline(
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 # manually for max memory savings
 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
...
@@ -1344,7 +1344,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         progress_bar.update()
         if callback is not None and i % callback_steps == 0:
-            callback(i, t, latents)
+            step_idx = i // getattr(self.scheduler, "order", 1)
+            callback(step_idx, t, latents)
 # If we do sequential model offloading, let's offload unet and controlnet
 # manually for max memory savings
...
@@ -382,7 +382,8 @@ class KandinskyPipeline(DiffusionPipeline):
     ).prev_sample
     if callback is not None and i % callback_steps == 0:
-        callback(i, t, latents)
+        step_idx = i // getattr(self.scheduler, "order", 1)
+        callback(step_idx, t, latents)
 # post-processing
 image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -475,7 +475,8 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
     ).prev_sample
     if callback is not None and i % callback_steps == 0:
-        callback(i, t, latents)
+        step_idx = i // getattr(self.scheduler, "order", 1)
+        callback(step_idx, t, latents)
 # 7. post-processing
 image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -610,7 +610,8 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
     ).prev_sample
     if callback is not None and i % callback_steps == 0:
-        callback(i, t, latents)
+        step_idx = i // getattr(self.scheduler, "order", 1)
+        callback(step_idx, t, latents)
 # post-processing
 image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -244,7 +244,8 @@ class KandinskyV22Pipeline(DiffusionPipeline):
     )[0]
     if callback is not None and i % callback_steps == 0:
-        callback(i, t, latents)
+        step_idx = i // getattr(self.scheduler, "order", 1)
+        callback(step_idx, t, latents)
 # post-processing
 image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -295,7 +295,8 @@ class KandinskyV22ControlnetPipeline(DiffusionPipeline):
     )[0]
     if callback is not None and i % callback_steps == 0:
-        callback(i, t, latents)
+        step_idx = i // getattr(self.scheduler, "order", 1)
+        callback(step_idx, t, latents)
 # post-processing
 image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
@@ -355,7 +355,8 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
     )[0]
     if callback is not None and i % callback_steps == 0:
-        callback(i, t, latents)
+        step_idx = i // getattr(self.scheduler, "order", 1)
+        callback(step_idx, t, latents)
 # post-processing
 image = self.movq.decode(latents, force_not_quantize=True)["sample"]
...
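
For reference, a hedged usage sketch of how the corrected index surfaces to a caller. It assumes a pipeline that still accepts the `callback` / `callback_steps` keyword arguments patched above (e.g. `StableDiffusionPipeline`, from which the Alt Diffusion hunks are copied) paired with `HeunDiscreteScheduler`, a second-order scheduler; the model id and prompt are illustrative only.

```python
# Usage sketch: assumes the callback/callback_steps API patched above and a
# second-order scheduler (HeunDiscreteScheduler); model id and prompt are
# illustrative.
import torch
from diffusers import StableDiffusionPipeline, HeunDiscreteScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

def on_step(step_idx, timestep, latents):
    # With this fix, step_idx counts inference steps (0..19 for 20 steps)
    # instead of raw loop iterations (roughly twice as many for Heun).
    print(f"step {step_idx} at timestep {timestep}")

pipe(
    "an astronaut riding a horse",
    num_inference_steps=20,
    callback=on_step,
    callback_steps=1,
)
```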