Unverified Commit 8b451eb6 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Fix config prints and save, load of pipelines (#2849)

* [Config] Fix config prints and save, load

* Only use potential nn.Modules for dtype and device

* Correct vae image processor

* make sure in_channels is not accessed directly

* make sure in channels is only accessed via config

* Make sure schedulers only access config attributes

* Make sure to access config in SAG

* Fix vae processor and make style

* add tests

* uP

* make style

* Fix more naming issues

* Final fix with vae config

* change more
parent 83691967
......@@ -247,7 +247,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
latents_shape = (
batch_size,
self.unet.in_channels,
self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
......
......@@ -283,7 +283,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
latents_shape = (
batch_size,
self.unet.in_channels,
self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
......
......@@ -268,7 +268,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
latents_shape = (
batch_size,
self.unet.in_channels,
self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
......
......@@ -649,7 +649,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -855,7 +855,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -910,7 +910,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -358,7 +358,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -561,7 +561,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
sigmas = sigmas.to(prompt_embeds.dtype)
# 6. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -722,7 +722,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -586,7 +586,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -929,7 +929,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# 5. Generate the inverted noise from the input image or any other image
# generated from the input prompt.
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -595,7 +595,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......@@ -701,7 +701,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
# Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
bh, hw1, hw2 = attn_map.shape
b, latent_channel, latent_h, latent_w = original_latents.shape
h = self.unet.attention_head_dim
h = self.unet.config.attention_head_dim
if isinstance(h, list):
h = h[-1]
......
......@@ -877,7 +877,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps
# 11. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
latents = self.prepare_latents(
shape=shape,
......
......@@ -772,7 +772,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size=batch_size,
num_channels_latents=num_channels_latents,
......
......@@ -623,7 +623,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -606,7 +606,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......
......@@ -12,7 +12,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import logging
from ...utils import deprecate, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -504,6 +504,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
@property
def in_channels(self):
    """Deprecated alias for ``self.config.in_channels``.

    Emits a deprecation warning (scheduled for removal in 1.0.0) and
    forwards to the config attribute. Callers should migrate to
    ``unet.config.in_channels``.
    """
    deprecate(
        "in_channels",
        "1.0.0",
        (
            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
            " `unet.config.in_channels` instead"
        ),
        standard_warn=False,
    )
    # Single source of truth is the registered config, not a plain attribute.
    return self.config.in_channels
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
......
......@@ -22,7 +22,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor
from ..utils import BaseOutput, deprecate, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......@@ -167,6 +167,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.variance_type = variance_type
@property
def num_train_timesteps(self):
    """Deprecated alias for ``self.config.num_train_timesteps``.

    Emits a deprecation warning (scheduled for removal in 1.0.0) and
    forwards to the config attribute. Callers should migrate to
    ``scheduler.config.num_train_timesteps``.
    """
    deprecate(
        "num_train_timesteps",
        "1.0.0",
        # Fix: close the backtick code-span after the attribute path, not
        # after the word "instead", so the rendered message reads correctly.
        "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use"
        " `scheduler.config.num_train_timesteps` instead",
        standard_warn=False,
    )
    # Single source of truth is the registered config, not a plain attribute.
    return self.config.num_train_timesteps
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
......
......@@ -183,7 +183,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
......
......@@ -193,7 +193,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment