Unverified commit 8b451eb6, authored by Patrick von Platen, committed by GitHub

Fix config prints and save/load of pipelines (#2849)

* [Config] Fix config prints and save/load

* Only use potential nn.Modules for dtype and device

* Correct vae image processor

* make sure in_channels is not accessed directly

* make sure in channels is only accessed via config

* Make sure schedulers only access config attributes

* Make sure to access config in SAG

* Fix vae processor and make style

* add tests

* up

* make style

* Fix more naming issues

* Final fix with vae config

* change more
parent 83691967
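The change running through every file below is the same: values registered with `register_to_config` are read through the frozen `.config` namespace instead of as plain instance attributes, with direct access kept alive only behind deprecation shims. A minimal before/after sketch, assuming a stock checkpoint (model id illustrative):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

num_channels = unet.in_channels         # still works, but now emits a deprecation warning
num_channels = unet.config.in_channels  # the supported access path going forward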
@@ -804,7 +804,7 @@ def main():
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                 timesteps = timesteps.long()

                 # Add noise to the latents according to the noise magnitude at each timestep
@@ -707,7 +707,7 @@ def main():
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                 timesteps = timesteps.long()

                 # Add noise to the latents according to the noise magnitude at each timestep
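Both training-script hunks above make the same change; a minimal standalone sketch of the new pattern (scheduler settings illustrative):

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
bsz = 4  # stand-in batch size
# Hyperparameters now come from the scheduler's config, not the object itself:
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
timesteps = timesteps.long()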
@@ -109,13 +109,6 @@ class ConfigMixin:
        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
        # or solve in a more general way.
        kwargs.pop("kwargs", None)
-        for key, value in kwargs.items():
-            try:
-                setattr(self, key, value)
-            except AttributeError as err:
-                logger.error(f"Can't set {key} with value {value} for {self}")
-                raise err
-
        if not hasattr(self, "_internal_dict"):
            internal_dict = kwargs
        else:
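With the `setattr` loop gone, `register_to_config` only populates the internal config dict; a minimal sketch of the resulting behavior for a custom class, assuming the public ConfigMixin API:

from diffusers.configuration_utils import ConfigMixin, register_to_config

class ToyModel(ConfigMixin):
    config_name = "toy_config.json"

    @register_to_config
    def __init__(self, hidden_size: int = 64):
        pass

toy = ToyModel()
print(toy.config.hidden_size)  # 64 — the supported access path
# `toy.hidden_size` is no longer set as an instance attribute by ConfigMixin.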
@@ -99,8 +99,8 @@ class VaeImageProcessor(ConfigMixin):
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
         """
         w, h = images.size
-        w, h = (x - x % self.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
-        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
+        w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
+        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
         return images

     def preprocess(
@@ -119,7 +119,7 @@ class VaeImageProcessor(ConfigMixin):
             )

         if isinstance(image[0], PIL.Image.Image):
-            if self.do_resize:
+            if self.config.do_resize:
                 image = [self.resize(i) for i in image]
             image = [np.array(i).astype(np.float32) / 255.0 for i in image]
             image = np.stack(image, axis=0)  # to np
@@ -129,23 +129,27 @@ class VaeImageProcessor(ConfigMixin):
             image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
             image = self.numpy_to_pt(image)
             _, _, height, width = image.shape
-            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+            if self.config.do_resize and (
+                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+            ):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}"
+                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
                     f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                 )

         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
             _, _, height, width = image.shape
-            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+            if self.config.do_resize and (
+                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+            ):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}"
+                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
                     f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                 )

         # expected range [0,1], normalize to [-1,1]
-        do_normalize = self.do_normalize
+        do_normalize = self.config.do_normalize
         if image.min() < 0:
             warnings.warn(
                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
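The processor's behavior is unchanged; only the attribute lookups move to `.config`. A small usage sketch (sizes illustrative):

import PIL.Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
image = PIL.Image.new("RGB", (513, 517))
# Each side is rounded down to the nearest multiple of vae_scale_factor:
print(processor.resize(image).size)  # (512, 512)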
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, apply_forward_hook
+from ..utils import BaseOutput, apply_forward_hook, deprecate
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -120,9 +120,19 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             if isinstance(self.config.sample_size, (list, tuple))
             else self.config.sample_size
         )
-        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25

+    @property
+    def block_out_channels(self):
+        deprecate(
+            "block_out_channels",
+            "1.0.0",
+            "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.block_out_channels
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, Decoder)):
             module.gradient_checkpointing = value
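A sketch of what the new shim does in practice, assuming a default-constructed VAE:

from diffusers import AutoencoderKL

vae = AutoencoderKL()
channels = vae.block_out_channels         # returns the config value, emits a deprecation warning
channels = vae.config.block_out_channels  # preferred, warning-free access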
@@ -19,7 +19,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@@ -190,6 +190,16 @@ class UNet1DModel(ModelMixin, ConfigMixin):
                 fc_dim=block_out_channels[-1] // 4,
             )

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -215,6 +215,16 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -20,7 +20,7 @@ import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, deprecate, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
@@ -412,6 +412,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
@@ -646,7 +646,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -121,17 +121,17 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(steps)
         step_generator = step_generator or generator

         # For backwards compatibility
-        if type(self.unet.sample_size) == int:
-            self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
+        if type(self.unet.config.sample_size) == int:
+            self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
         input_dims = self.get_input_dims()
         self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
         if noise is None:
             noise = randn_tensor(
                 (
                     batch_size,
-                    self.unet.in_channels,
-                    self.unet.sample_size[0],
-                    self.unet.sample_size[1],
+                    self.unet.config.in_channels,
+                    self.unet.config.sample_size[0],
+                    self.unet.config.sample_size[1],
                 ),
                 generator=generator,
                 device=self.device,
@@ -158,7 +158,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
                 images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])

             pixels_per_second = (
-                self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
+                self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
             )
             mask_start = int(mask_start_secs * pixels_per_second)
             mask_end = int(mask_end_secs * pixels_per_second)
@@ -540,7 +540,7 @@ class AudioLDMPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_waveforms_per_prompt,
             num_channels_latents,
@@ -61,7 +61,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
                 to make generation deterministic.
             audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
                 The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
-                `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
+                `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
@@ -73,27 +73,29 @@ class DanceDiffusionPipeline(DiffusionPipeline):
         if audio_length_in_s is None:
             audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate

-        sample_size = audio_length_in_s * self.unet.sample_rate
+        sample_size = audio_length_in_s * self.unet.config.sample_rate

         down_scale_factor = 2 ** len(self.unet.up_blocks)
         if sample_size < 3 * down_scale_factor:
             raise ValueError(
                 f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
-                f" {3 * down_scale_factor / self.unet.sample_rate}."
+                f" {3 * down_scale_factor / self.unet.config.sample_rate}."
             )

         original_sample_size = int(sample_size)
         if sample_size % down_scale_factor != 0:
-            sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
+            sample_size = (
+                (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1
+            ) * down_scale_factor
             logger.info(
-                f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
-                f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
+                f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled"
+                f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising"
                 " process."
             )
         sample_size = int(sample_size)

         dtype = next(iter(self.unet.parameters())).dtype
-        shape = (batch_size, self.unet.in_channels, sample_size)
+        shape = (batch_size, self.unet.config.in_channels, sample_size)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
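A worked example of the rounding above, with illustrative numbers:

audio_length_in_s, sample_rate, down_scale_factor = 4.1, 16000, 256
sample_size = audio_length_in_s * sample_rate  # 65600.0 samples requested
if sample_size % down_scale_factor != 0:
    # round up to the next multiple of down_scale_factor: 257 * 256 = 65792
    sample_size = (sample_size // down_scale_factor + 1) * down_scale_factor
print(int(sample_size))  # 65792; trimmed back to 65600 after denoising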
@@ -79,10 +79,15 @@ class DDIMPipeline(DiffusionPipeline):
         """

         # Sample gaussian noise to begin loop
-        if isinstance(self.unet.sample_size, int):
-            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.config.sample_size, int):
+            image_shape = (
+                batch_size,
+                self.unet.config.in_channels,
+                self.unet.config.sample_size,
+                self.unet.config.sample_size,
+            )
         else:
-            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -67,10 +67,15 @@ class DDPMPipeline(DiffusionPipeline):
             True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
         # Sample gaussian noise to begin loop
-        if isinstance(self.unet.sample_size, int):
-            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.config.sample_size, int):
+            image_shape = (
+                batch_size,
+                self.unet.config.in_channels,
+                self.unet.config.sample_size,
+                self.unet.config.sample_size,
+            )
         else:
-            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

         if self.device.type == "mps":
             # randn does not work reproducibly on mps
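The DDIM and DDPM hunks above handle the same two cases; the resulting shapes, with illustrative values:

batch_size, in_channels = 2, 3
sample_size = 64                        # int config -> square images
image_shape = (batch_size, in_channels, sample_size, sample_size)  # (2, 3, 64, 64)
sample_size = (64, 128)                 # tuple config -> rectangular images
image_shape = (batch_size, in_channels, *sample_size)              # (2, 3, 64, 128)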
@@ -135,7 +135,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0]

         # get the initial random noise unless the user supplied it
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -112,7 +112,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
         height, width = image.shape[-2:]

         # in_channels should be 6: 3 for latents, 3 for low resolution image
-        latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
+        latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
         latents_dtype = next(self.unet.parameters()).dtype
         latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
@@ -73,7 +73,7 @@ class LDMPipeline(DiffusionPipeline):
         """

         latents = randn_tensor(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
             generator=generator,
         )
         latents = latents.to(self.device)
@@ -506,6 +506,21 @@ class DiffusionPipeline(ConfigMixin):
             # set models
             setattr(self, name, module)

+    def __setattr__(self, name: str, value: Any):
+        if hasattr(self, name) and hasattr(self.config, name):
+            # We need to overwrite the config if name exists in config
+            if isinstance(getattr(self.config, name), (tuple, list)):
+                if self.config[name][0] is not None:
+                    class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
+                else:
+                    class_library_tuple = (None, None)
+
+                self.register_to_config(**{name: class_library_tuple})
+            else:
+                self.register_to_config(**{name: value})
+
+        super().__setattr__(name, value)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
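The new `__setattr__` keeps the pipeline config in sync when a component is replaced after loading, so `save_pretrained` records the swapped class. A hedged sketch (model id illustrative):

from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# The config entry is a (library, class_name) tuple and now reflects the swap:
print(pipe.config.scheduler)  # ("diffusers", "DDIMScheduler")
pipe.save_pretrained("./sd-ddim")  # reloading picks up DDIMScheduler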
@@ -619,9 +634,11 @@ class DiffusionPipeline(ConfigMixin):
                 f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
             )

-        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        module_names, _ = self._get_signature_keys(self)
+        module_names = [m for m in module_names if hasattr(self, m)]
+
         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
-        for name in module_names.keys():
+        for name in module_names:
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 module.to(torch_device, torch_dtype)
@@ -646,8 +663,10 @@ class DiffusionPipeline(ConfigMixin):
         Returns:
             `torch.device`: The torch device on which the pipeline is located.
         """
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        for name in module_names.keys():
+        module_names, _ = self._get_signature_keys(self)
+        module_names = [m for m in module_names if hasattr(self, m)]
+
+        for name in module_names:
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 return module.device
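With `_get_signature_keys`, `to` and `device` only consider components that actually exist on the pipeline and are `nn.Module`s, so tokenizers and the like are skipped. A sketch (requires a CUDA device; model id illustrative):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.to("cuda")
print(pipe.device)  # device of the first nn.Module component, e.g. cuda:0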
@@ -1420,6 +1439,8 @@ class DiffusionPipeline(ConfigMixin):
                 fn_recursive_set_mem_eff(child)

         module_names, _, _ = self.extract_init_dict(dict(self.config))
+        module_names = [m for m in module_names if hasattr(self, m)]
+
         for module_name in module_names:
             module = getattr(self, module_name)
             if isinstance(module, torch.nn.Module):
@@ -1451,6 +1472,8 @@ class DiffusionPipeline(ConfigMixin):
     def set_attention_slice(self, slice_size: Optional[int]):
         module_names, _, _ = self.extract_init_dict(dict(self.config))
+        module_names = [m for m in module_names if hasattr(self, m)]
+
         for module_name in module_names:
             module = getattr(self, module_name)
             if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
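The added `hasattr` filter matters because config keys need not correspond to live attributes (for example, components deleted after init); the guard in isolation:

class Holder:
    pass

holder = Holder()
holder.unet = object()
candidate_names = ["unet", "safety_checker"]  # e.g. keys from the config
present = [m for m in candidate_names if hasattr(holder, m)]
print(present)  # ['unet'] — missing entries are skipped instead of raising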
@@ -77,7 +77,7 @@ class PNDMPipeline(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = randn_tensor(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
             generator=generator,
             device=self.device,
         )
@@ -476,7 +476,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,