Unverified Commit bdd16116 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Schedulers] Fix callback steps (#5261)

* fix all

* make fix copies

* make fix copies
parent c8b0f0eb
......@@ -932,7 +932,8 @@ class StableDiffusionRepaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
# call the callback, if provided
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
t_last = t
......
......@@ -771,7 +771,8 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
......
......@@ -389,7 +389,8 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
......
......@@ -432,7 +432,8 @@ class RDMPipeline(DiffusionPipeline):
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
......
......@@ -736,7 +736,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
......@@ -765,7 +765,8 @@ class AltDiffusionImg2ImgPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
......@@ -542,7 +542,8 @@ class AudioLDMPipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# 8. Post-processing
mel_spectrogram = self.decode_latents(latents)
......
......@@ -945,7 +945,8 @@ class AudioLDM2Pipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
self.maybe_free_model_hooks()
......
......@@ -1005,7 +1005,8 @@ class StableDiffusionControlNetPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
......
......@@ -1087,7 +1087,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
......
......@@ -1351,7 +1351,8 @@ class StableDiffusionControlNetInpaintPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
......
......@@ -1507,7 +1507,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
......
......@@ -1158,7 +1158,8 @@ class StableDiffusionXLControlNetPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# manually for max memory savings
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
......
......@@ -1344,7 +1344,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
......
......@@ -382,7 +382,8 @@ class KandinskyPipeline(DiffusionPipeline):
).prev_sample
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -475,7 +475,8 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
).prev_sample
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# 7. post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -610,7 +610,8 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
).prev_sample
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -244,7 +244,8 @@ class KandinskyV22Pipeline(DiffusionPipeline):
)[0]
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -295,7 +295,8 @@ class KandinskyV22ControlnetPipeline(DiffusionPipeline):
)[0]
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
......@@ -355,7 +355,8 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
)[0]
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment