Unverified Commit 8b451eb6 authored by Patrick von Platen, committed by GitHub

Fix config prints and save, load of pipelines (#2849)

* [Config] Fix config prints and save, load

* Only use potential nn.Modules for dtype and device

* Correct vae image processor

* make sure in_channels is not accessed directly

* make sure in channels is only accessed via config

* Make sure schedulers only access config attributes

* Make sure to access config in SAG

* Fix vae processor and make style

* add tests

* uP

* make style

* Fix more naming issues

* Final fix with vae config

* change more
parent 83691967
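The net effect of this commit for users is that model and scheduler hyperparameters are read through the `config` attribute instead of as plain instance attributes. A minimal sketch of the before/after access pattern, assuming a loaded Stable Diffusion pipeline (the model id is only illustrative):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Deprecated: reading hyperparameters as direct attributes
# num_channels_latents = pipe.unet.in_channels

# Preferred after this change: go through the registered config
num_channels_latents = pipe.unet.config.in_channels
num_train_timesteps = pipe.scheduler.config.num_train_timesteps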
@@ -804,7 +804,7 @@ def main():
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()
         # Add noise to the latents according to the noise magnitude at each timestep

@@ -707,7 +707,7 @@ def main():
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()
         # Add noise to the latents according to the noise magnitude at each timestep

@@ -109,13 +109,6 @@ class ConfigMixin:
         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         # or solve in a more general way.
         kwargs.pop("kwargs", None)
-        for key, value in kwargs.items():
-            try:
-                setattr(self, key, value)
-            except AttributeError as err:
-                logger.error(f"Can't set {key} with value {value} for {self}")
-                raise err
         if not hasattr(self, "_internal_dict"):
             internal_dict = kwargs
         else:

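With the `setattr` loop removed from `ConfigMixin.register_to_config`, registered config values are no longer copied onto the instance as plain attributes; the supported way to read them is through `config`. A small sketch with a scheduler, since schedulers are plain `ConfigMixin` classes:

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# Registered init arguments are read via the frozen config
print(scheduler.config.num_train_timesteps)  # 1000
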
@@ -99,8 +99,8 @@ class VaeImageProcessor(ConfigMixin):
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
         """
         w, h = images.size
-        w, h = (x - x % self.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
-        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
+        w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
+        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
         return images

     def preprocess(
@@ -119,7 +119,7 @@ class VaeImageProcessor(ConfigMixin):
         )

         if isinstance(image[0], PIL.Image.Image):
-            if self.do_resize:
+            if self.config.do_resize:
                 image = [self.resize(i) for i in image]
             image = [np.array(i).astype(np.float32) / 255.0 for i in image]
             image = np.stack(image, axis=0)  # to np
@@ -129,23 +129,27 @@ class VaeImageProcessor(ConfigMixin):
             image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
             image = self.numpy_to_pt(image)
             _, _, height, width = image.shape
-            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+            if self.config.do_resize and (
+                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+            ):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}"
+                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
                     f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                 )
         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
             _, _, height, width = image.shape
-            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+            if self.config.do_resize and (
+                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+            ):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}"
+                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
                     f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                 )
         # expected range [0,1], normalize to [-1,1]
-        do_normalize = self.do_normalize
+        do_normalize = self.config.do_normalize
        if image.min() < 0:
             warnings.warn(
                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "

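Since `VaeImageProcessor` registers its `__init__` arguments via `register_to_config`, the options it reads internally (`do_resize`, `vae_scale_factor`, `resample`, `do_normalize`) now come from `self.config`. A sketch of how the same values are reachable from user code, with constructor defaults assumed from the class signature:

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)

# Options passed to __init__ are registered to the config and read from there internally
print(processor.config.do_resize)         # True by default
print(processor.config.vae_scale_factor)  # 8
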
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, apply_forward_hook
+from ..utils import BaseOutput, apply_forward_hook, deprecate
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder

@@ -120,9 +120,19 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             if isinstance(self.config.sample_size, (list, tuple))
             else self.config.sample_size
         )
-        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25

+    @property
+    def block_out_channels(self):
+        deprecate(
+            "block_out_channels",
+            "1.0.0",
+            "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`",
+            standard_warn=False,
+        )
+        return self.config.block_out_channels
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, Decoder)):
             module.gradient_checkpointing = value

@@ -19,7 +19,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block

@@ -190,6 +190,16 @@ class UNet1DModel(ModelMixin, ConfigMixin):
                 fc_dim=block_out_channels[-1] // 4,
             )

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     def forward(
         self,
         sample: torch.FloatTensor,

@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block

@@ -215,6 +215,16 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     def forward(
         self,
         sample: torch.FloatTensor,

@@ -20,7 +20,7 @@ import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, deprecate, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin

@@ -412,6 +412,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""

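The `in_channels` property added to each UNet keeps old code working while steering users toward the config. A short sketch of the expected behavior, assuming a `UNet2DConditionModel` instantiated with default arguments:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel()

# Still works, but emits a deprecation warning via `deprecate(...)`
channels_old = unet.in_channels

# Warning-free access path going forward
channels_new = unet.config.in_channels
assert channels_old == channels_new
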
@@ -646,7 +646,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,

@@ -121,17 +121,17 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(steps)
         step_generator = step_generator or generator

         # For backwards compatibility
-        if type(self.unet.sample_size) == int:
-            self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
+        if type(self.unet.config.sample_size) == int:
+            self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
         input_dims = self.get_input_dims()
         self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
         if noise is None:
             noise = randn_tensor(
                 (
                     batch_size,
-                    self.unet.in_channels,
-                    self.unet.sample_size[0],
-                    self.unet.sample_size[1],
+                    self.unet.config.in_channels,
+                    self.unet.config.sample_size[0],
+                    self.unet.config.sample_size[1],
                 ),
                 generator=generator,
                 device=self.device,
@@ -158,7 +158,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
                 images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])

             pixels_per_second = (
-                self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
+                self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
             )
             mask_start = int(mask_start_secs * pixels_per_second)
             mask_end = int(mask_end_secs * pixels_per_second)

@@ -540,7 +540,7 @@ class AudioLDMPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_waveforms_per_prompt,
             num_channels_latents,

@@ -61,7 +61,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
                 to make generation deterministic.
             audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
                 The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
-                `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
+                `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
@@ -73,27 +73,29 @@ class DanceDiffusionPipeline(DiffusionPipeline):
         if audio_length_in_s is None:
             audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate

-        sample_size = audio_length_in_s * self.unet.sample_rate
+        sample_size = audio_length_in_s * self.unet.config.sample_rate

         down_scale_factor = 2 ** len(self.unet.up_blocks)
         if sample_size < 3 * down_scale_factor:
             raise ValueError(
                 f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
-                f" {3 * down_scale_factor / self.unet.sample_rate}."
+                f" {3 * down_scale_factor / self.unet.config.sample_rate}."
             )

         original_sample_size = int(sample_size)
         if sample_size % down_scale_factor != 0:
-            sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
+            sample_size = (
+                (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1
+            ) * down_scale_factor
             logger.info(
-                f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
-                f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
+                f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled"
+                f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising"
                 " process."
             )
         sample_size = int(sample_size)

         dtype = next(iter(self.unet.parameters())).dtype
-        shape = (batch_size, self.unet.in_channels, sample_size)
+        shape = (batch_size, self.unet.config.in_channels, sample_size)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"

@@ -79,10 +79,15 @@ class DDIMPipeline(DiffusionPipeline):
         """

         # Sample gaussian noise to begin loop
-        if isinstance(self.unet.sample_size, int):
-            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.config.sample_size, int):
+            image_shape = (
+                batch_size,
+                self.unet.config.in_channels,
+                self.unet.config.sample_size,
+                self.unet.config.sample_size,
+            )
         else:
-            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(

@@ -67,10 +67,15 @@ class DDPMPipeline(DiffusionPipeline):
                 True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
         # Sample gaussian noise to begin loop
-        if isinstance(self.unet.sample_size, int):
-            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.config.sample_size, int):
+            image_shape = (
+                batch_size,
+                self.unet.config.in_channels,
+                self.unet.config.sample_size,
+                self.unet.config.sample_size,
+            )
         else:
-            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

         if self.device.type == "mps":
             # randn does not work reproducibly on mps

@@ -135,7 +135,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0]

         # get the initial random noise unless the user supplied it
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"

@@ -112,7 +112,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
         height, width = image.shape[-2:]

         # in_channels should be 6: 3 for latents, 3 for low resolution image
-        latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
+        latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
         latents_dtype = next(self.unet.parameters()).dtype

         latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

@@ -73,7 +73,7 @@ class LDMPipeline(DiffusionPipeline):
         """

         latents = randn_tensor(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
             generator=generator,
         )
         latents = latents.to(self.device)

@@ -506,6 +506,21 @@ class DiffusionPipeline(ConfigMixin):
             # set models
             setattr(self, name, module)

+    def __setattr__(self, name: str, value: Any):
+        if hasattr(self, name) and hasattr(self.config, name):
+            # We need to overwrite the config if name exists in config
+            if isinstance(getattr(self.config, name), (tuple, list)):
+                if self.config[name][0] is not None:
+                    class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
+                else:
+                    class_library_tuple = (None, None)
+
+                self.register_to_config(**{name: class_library_tuple})
+            else:
+                self.register_to_config(**{name: value})
+
+        super().__setattr__(name, value)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
@@ -619,9 +634,11 @@ class DiffusionPipeline(ConfigMixin):
                 f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
             )

-        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        module_names, _ = self._get_signature_keys(self)
+        module_names = [m for m in module_names if hasattr(self, m)]
+
         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
-        for name in module_names.keys():
+        for name in module_names:
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 module.to(torch_device, torch_dtype)
@@ -646,8 +663,10 @@ class DiffusionPipeline(ConfigMixin):
         Returns:
             `torch.device`: The torch device on which the pipeline is located.
         """
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        for name in module_names.keys():
+        module_names, _ = self._get_signature_keys(self)
+        module_names = [m for m in module_names if hasattr(self, m)]
+
+        for name in module_names:
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 return module.device
@@ -1420,6 +1439,8 @@ class DiffusionPipeline(ConfigMixin):
                 fn_recursive_set_mem_eff(child)

         module_names, _, _ = self.extract_init_dict(dict(self.config))
+        module_names = [m for m in module_names if hasattr(self, m)]
+
         for module_name in module_names:
             module = getattr(self, module_name)
             if isinstance(module, torch.nn.Module):
@@ -1451,6 +1472,8 @@ class DiffusionPipeline(ConfigMixin):
     def set_attention_slice(self, slice_size: Optional[int]):
         module_names, _, _ = self.extract_init_dict(dict(self.config))
+        module_names = [m for m in module_names if hasattr(self, m)]
+
         for module_name in module_names:
             module = getattr(self, module_name)
             if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):

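The new `__setattr__` means that replacing a pipeline component after loading also updates the pipeline's registered config (and therefore what `save_pretrained` writes out), while `to()` and `device` now resolve modules from the init signature rather than the config dict. A sketch of the intended behavior, with the model id and save path purely illustrative:

from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Swapping a component now also updates its (library, class) entry in the config
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
print(pipe.config.scheduler)  # expected: ("diffusers", "DDIMScheduler")

# so the swapped class is what save_pretrained / from_pretrained round-trip
pipe.save_pretrained("./sd-with-ddim")
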
@@ -77,7 +77,7 @@ class PNDMPipeline(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = randn_tensor(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
             generator=generator,
             device=self.device,
         )

@@ -476,7 +476,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,