Unverified Commit 8b451eb6 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Fix config prints and save, load of pipelines (#2849)

* [Config] Fix config prints and save, load

* Only use potential nn.Modules for dtype and device

* Correct vae image processor

* make sure in_channels is not accessed directly

* make sure in channels is only accessed via config

* Make sure schedulers only access config attributes

* Make sure to access config in SAG

* Fix vae processor and make style

* add tests

* Update

* make style

* Fix more naming issues

* Final fix with vae config

* change more
parent 83691967
...@@ -247,7 +247,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -247,7 +247,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
latents_shape = ( latents_shape = (
batch_size, batch_size,
self.unet.in_channels, self.unet.config.in_channels,
height // self.vae_scale_factor, height // self.vae_scale_factor,
width // self.vae_scale_factor, width // self.vae_scale_factor,
) )
......
...@@ -283,7 +283,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): ...@@ -283,7 +283,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
latents_shape = ( latents_shape = (
batch_size, batch_size,
self.unet.in_channels, self.unet.config.in_channels,
height // self.vae_scale_factor, height // self.vae_scale_factor,
width // self.vae_scale_factor, width // self.vae_scale_factor,
) )
......
...@@ -268,7 +268,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -268,7 +268,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
latents_shape = ( latents_shape = (
batch_size, batch_size,
self.unet.in_channels, self.unet.config.in_channels,
height // self.vae_scale_factor, height // self.vae_scale_factor,
width // self.vae_scale_factor, width // self.vae_scale_factor,
) )
......
...@@ -649,7 +649,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): ...@@ -649,7 +649,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -855,7 +855,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion ...@@ -855,7 +855,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -910,7 +910,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade ...@@ -910,7 +910,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 6. Prepare latent variables # 6. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -358,7 +358,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -358,7 +358,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -561,7 +561,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade ...@@ -561,7 +561,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
sigmas = sigmas.to(prompt_embeds.dtype) sigmas = sigmas.to(prompt_embeds.dtype)
# 6. Prepare latent variables # 6. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -722,7 +722,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -722,7 +722,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -586,7 +586,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -586,7 +586,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -929,7 +929,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -929,7 +929,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# 5. Generate the inverted noise from the input image or any other image # 5. Generate the inverted noise from the input image or any other image
# generated from the input prompt. # generated from the input prompt.
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -595,7 +595,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) ...@@ -595,7 +595,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
...@@ -701,7 +701,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) ...@@ -701,7 +701,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
# Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
bh, hw1, hw2 = attn_map.shape bh, hw1, hw2 = attn_map.shape
b, latent_channel, latent_h, latent_w = original_latents.shape b, latent_channel, latent_h, latent_w = original_latents.shape
h = self.unet.attention_head_dim h = self.unet.config.attention_head_dim
if isinstance(h, list): if isinstance(h, list):
h = h[-1] h = h[-1]
......
...@@ -877,7 +877,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin): ...@@ -877,7 +877,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 11. Prepare latent variables # 11. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
latents = self.prepare_latents( latents = self.prepare_latents(
shape=shape, shape=shape,
......
...@@ -772,7 +772,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin ...@@ -772,7 +772,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 6. Prepare latent variables # 6. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size=batch_size, batch_size=batch_size,
num_channels_latents=num_channels_latents, num_channels_latents=num_channels_latents,
......
...@@ -623,7 +623,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): ...@@ -623,7 +623,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -606,7 +606,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin): ...@@ -606,7 +606,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 5. Prepare latent variables # 5. Prepare latent variables
num_channels_latents = self.unet.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
......
...@@ -12,7 +12,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel ...@@ -12,7 +12,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModel from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import logging from ...utils import deprecate, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -504,6 +504,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -504,6 +504,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
) )
@property
def in_channels(self):
    """Deprecated accessor kept for backward compatibility.

    Reads the value from ``self.config``; direct attribute access is
    scheduled for removal in 1.0.0 — use ``unet.config.in_channels``.
    """
    message = (
        "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
        " `unet.config.in_channels` instead"
    )
    deprecate("in_channels", "1.0.0", message, standard_warn=False)
    return self.config.in_channels
@property @property
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
......
...@@ -22,7 +22,7 @@ import numpy as np ...@@ -22,7 +22,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput, deprecate, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
...@@ -167,6 +167,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -167,6 +167,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.variance_type = variance_type self.variance_type = variance_type
@property
def num_train_timesteps(self):
    """Deprecated accessor kept for backward compatibility.

    Reads the value from ``self.config``; direct attribute access is
    scheduled for removal in 1.0.0 — use ``scheduler.config.num_train_timesteps``.
    """
    # Fix: the original message closed the backtick span after "instead"
    # (`...num_train_timesteps instead`), mis-rendering the attribute path.
    deprecate(
        "num_train_timesteps",
        "1.0.0",
        "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated."
        " Please use `scheduler.config.num_train_timesteps` instead",
        standard_warn=False,
    )
    return self.config.num_train_timesteps
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
""" """
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
......
...@@ -183,7 +183,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -183,7 +183,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
""" """
timesteps = ( timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1] .round()[::-1][:-1]
.copy() .copy()
.astype(np.int64) .astype(np.int64)
......
...@@ -193,7 +193,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -193,7 +193,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
""" """
timesteps = ( timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1] .round()[::-1][:-1]
.copy() .copy()
.astype(np.int64) .astype(np.int64)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment