Unverified Commit bdd16116 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Schedulers] Fix callback steps (#5261)

* fix all

* make fix copies

* make fix copies
parent c8b0f0eb
...@@ -414,7 +414,8 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline): ...@@ -414,7 +414,8 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
-                       callback(i, t, latents)
+                       step_idx = i // getattr(self.scheduler, "order", 1)
+                       callback(step_idx, t, latents)

        return latents.clone().detach()

    @torch.no_grad()
......
...@@ -18,7 +18,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin ...@@ -18,7 +18,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.outputs import BaseOutput
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
...@@ -426,7 +426,10 @@ class UniDiffuserPipeline(DiffusionPipeline): ...@@ -426,7 +426,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-           adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+           if not self.use_peft_backend:
+               adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+           else:
+               scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
...@@ -551,6 +554,10 @@ class UniDiffuserPipeline(DiffusionPipeline): ...@@ -551,6 +554,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

+       if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+           # Retrieve the original scale by scaling back the LoRA layers
+           unscale_lora_layers(self.text_encoder)

        return prompt_embeds, negative_prompt_embeds

    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
...@@ -1367,7 +1374,8 @@ class UniDiffuserPipeline(DiffusionPipeline): ...@@ -1367,7 +1374,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
-                       callback(i, t, latents)
+                       step_idx = i // getattr(self.scheduler, "order", 1)
+                       callback(step_idx, t, latents)

        # 9. Post-processing
        image = None
......
...@@ -539,7 +539,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): ...@@ -539,7 +539,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
                # call the callback, if provided
                if callback is not None and i % callback_steps == 0:
-                   callback(i, t, latents)
+                   step_idx = i // getattr(self.scheduler, "order", 1)
+                   callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
...@@ -380,7 +380,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -380,7 +380,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
                # call the callback, if provided
                if callback is not None and i % callback_steps == 0:
-                   callback(i, t, latents)
+                   step_idx = i // getattr(self.scheduler, "order", 1)
+                   callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
...@@ -454,7 +454,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -454,7 +454,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
                # call the callback, if provided
                if callback is not None and i % callback_steps == 0:
-                   callback(i, t, latents)
+                   step_idx = i // getattr(self.scheduler, "order", 1)
+                   callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
......
...@@ -352,7 +352,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -352,7 +352,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
            ).prev_sample

            if callback is not None and i % callback_steps == 0:
-               callback(i, t, latents)
+               step_idx = i // getattr(self.scheduler, "order", 1)
+               callback(step_idx, t, latents)

        # 10. Scale and decode the image latents with vq-vae
        latents = self.vqgan.config.scale_factor * latents
......
...@@ -436,7 +436,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline): ...@@ -436,7 +436,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
            ).prev_sample

            if callback is not None and i % callback_steps == 0:
-               callback(i, t, latents)
+               step_idx = i // getattr(self.scheduler, "order", 1)
+               callback(step_idx, t, latents)

        # 10. Denormalize the latents
        latents = latents * self.config.latent_mean - self.config.latent_std
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment