Unverified Commit 703307ef authored by Patrick von Platen, committed by GitHub

Fix config deprecation (#3129)



* Better deprecation message

* Better deprecation message

* Better doc string

* Fixes

* fix more

* fix more

* Improve __getattr__

* correct more

* fix more

* fix

* Improve more

* more improvements

* fix more

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* make style

* Fix all rest & add tests & remove old deprecation fns

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent ed8fd383
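The change applied throughout this commit is the same one-liner: values that live in a model's or scheduler's config are now read through the `config` object instead of as direct attributes, and the new `__getattr__` overwrites below deprecate the old access path gracefully. A minimal before/after sketch (instantiating `UNet2DModel` with its defaults is just for illustration):

    from diffusers import UNet2DModel

    unet = UNet2DModel()

    # Deprecated: direct attribute access; after this PR it emits a
    # FutureWarning via the new `__getattr__` overwrite.
    channels = unet.in_channels

    # Preferred: read the value from the frozen config object.
    channels = unet.config.in_channels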
@@ -372,9 +372,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
@@ -425,9 +425,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
...
@@ -452,9 +452,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
@@ -505,9 +505,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
...
@@ -118,6 +118,24 @@ class ConfigMixin:
         self._internal_dict = FrozenDict(internal_dict)

+    def __getattr__(self, name: str) -> Any:
+        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
+
+        This function is mostly copied from PyTorch's `__getattr__` overwrite:
+        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        """
+
+        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+        is_attribute = name in self.__dict__
+
+        if is_in_config and not is_attribute:
+            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
+            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
+            return self._internal_dict[name]
+
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
     def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
         """
         Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
...
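The guard above means the fallback only fires for names that are present in `_internal_dict` but absent from the instance `__dict__`; everything else raises `AttributeError` as before. A small usage sketch of the resulting behavior (`DDPMScheduler` chosen for illustration):

    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler()

    # Config keys still resolve, but now emit a FutureWarning that
    # points to `scheduler.config.num_train_timesteps`.
    steps = scheduler.num_train_timesteps
    assert steps == scheduler.config.num_train_timesteps

    # Names that are neither attributes nor config keys raise, as before.
    try:
        scheduler.does_not_exist
    except AttributeError as e:
        print(e)  # 'DDPMScheduler' object has no attribute 'does_not_exist'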
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, apply_forward_hook, deprecate
+from ..utils import BaseOutput, apply_forward_hook
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -123,16 +123,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25

-    @property
-    def block_out_channels(self):
-        deprecate(
-            "block_out_channels",
-            "1.0.0",
-            "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`",
-            standard_warn=False,
-        )
-        return self.config.block_out_channels
-
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, Decoder)):
             module.gradient_checkpointing = value
...
@@ -17,7 +17,7 @@
 import inspect
 import os
 from functools import partial
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 from torch import Tensor, device
@@ -32,6 +32,7 @@ from ..utils import (
     WEIGHTS_NAME,
     _add_variant,
     _get_model_file,
+    deprecate,
     is_accelerate_available,
     is_safetensors_available,
     is_torch_version,
@@ -156,6 +157,24 @@ class ModelMixin(torch.nn.Module):
     def __init__(self):
         super().__init__()

+    def __getattr__(self, name: str) -> Any:
+        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
+
+        We need to overwrite `__getattr__` here in addition so that we don't trigger
+        `torch.nn.Module`'s `__getattr__`:
+        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        """
+
+        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+        is_attribute = name in self.__dict__
+
+        if is_in_config and not is_attribute:
+            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
+            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
+            return self._internal_dict[name]
+
+        # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        return super().__getattr__(name)
+
     @property
     def is_gradient_checkpointing(self) -> bool:
         """
...
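Unlike the `ConfigMixin` version, this overwrite must end by delegating to `super().__getattr__` because `torch.nn.Module` keeps parameters, buffers, and submodules in `_parameters`/`_buffers`/`_modules` rather than in the instance `__dict__`, so only PyTorch's own `__getattr__` can resolve them. A self-contained sketch of that storage quirk (the `Toy` module is illustrative):

    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            # nn.Module.__setattr__ stores this in self._modules,
            # not in self.__dict__.
            self.linear = nn.Linear(2, 2)

    toy = Toy()
    assert "linear" not in toy.__dict__        # not a plain attribute
    assert isinstance(toy.linear, nn.Linear)   # resolved by nn.Module.__getattr__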
@@ -19,7 +19,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate
+from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@@ -190,16 +190,6 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             fc_dim=block_out_channels[-1] // 4,
         )

-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
-            standard_warn=False,
-        )
-        return self.config.in_channels
-
     def forward(
         self,
         sample: torch.FloatTensor,
...
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate
+from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -216,16 +216,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
-            standard_warn=False,
-        )
-        return self.config.in_channels
-
     def forward(
         self,
         sample: torch.FloatTensor,
...
@@ -21,7 +21,7 @@ import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, deprecate, logging
+from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
@@ -447,16 +447,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
-            standard_warn=False,
-        )
-        return self.config.in_channels
-
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
...
@@ -508,7 +508,7 @@ class DiffusionPipeline(ConfigMixin):
             setattr(self, name, module)

     def __setattr__(self, name: str, value: Any):
-        if hasattr(self, name) and hasattr(self.config, name):
+        if name in self.__dict__ and hasattr(self.config, name):
             # We need to overwrite the config if name exists in config
             if isinstance(getattr(self.config, name), (tuple, list)):
                 if value is not None and self.config[name][0] is not None:
@@ -648,26 +648,25 @@ class DiffusionPipeline(ConfigMixin):
         )

         module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                module.to(torch_device, torch_dtype)
-                if (
-                    module.dtype == torch.float16
-                    and str(torch_device) in ["cpu"]
-                    and not silence_dtype_warnings
-                    and not is_offloaded
-                ):
-                    logger.warning(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
-                        " is not recommended to move them to `cpu` as running them will fail. Please make"
-                        " sure to use an accelerator to run the pipeline in inference, due to the lack of"
-                        " support for`float16` operations on this device in PyTorch. Please, remove the"
-                        " `torch_dtype=torch.float16` argument, or use another device for inference."
-                    )
+        for module in modules:
+            module.to(torch_device, torch_dtype)
+            if (
+                module.dtype == torch.float16
+                and str(torch_device) in ["cpu"]
+                and not silence_dtype_warnings
+                and not is_offloaded
+            ):
+                logger.warning(
+                    "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
+                    " is not recommended to move them to `cpu` as running them will fail. Please make"
+                    " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+                    " support for`float16` operations on this device in PyTorch. Please, remove the"
+                    " `torch_dtype=torch.float16` argument, or use another device for inference."
+                )
         return self

     @property
@@ -677,12 +676,12 @@ class DiffusionPipeline(ConfigMixin):
             `torch.device`: The torch device on which the pipeline is located.
         """
         module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                return module.device
+        for module in modules:
+            return module.device

         return torch.device("cpu")

     @classmethod
@@ -1451,13 +1450,12 @@ class DiffusionPipeline(ConfigMixin):
             for child in module.children():
                 fn_recursive_set_mem_eff(child)

-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        module_names = [m for m in module_names if hasattr(self, m)]
+        module_names, _ = self._get_signature_keys(self)
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

-        for module_name in module_names:
-            module = getattr(self, module_name)
-            if isinstance(module, torch.nn.Module):
-                fn_recursive_set_mem_eff(module)
+        for module in modules:
+            fn_recursive_set_mem_eff(module)

     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -1484,10 +1482,9 @@ class DiffusionPipeline(ConfigMixin):
         self.enable_attention_slicing(None)

     def set_attention_slice(self, slice_size: Optional[int]):
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        module_names = [m for m in module_names if hasattr(self, m)]
+        module_names, _ = self._get_signature_keys(self)
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]

-        for module_name in module_names:
-            module = getattr(self, module_name)
-            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
-                module.set_attention_slice(slice_size)
+        for module in modules:
+            module.set_attention_slice(slice_size)
...
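The refactor above repeats one pattern several times: derive component names from the pipeline's `__init__` signature rather than from the config dict, fetch them defensively, and keep only the `torch.nn.Module` instances. As a paraphrase of that pattern (the helper name `pipeline_torch_modules` is not part of the library):

    import torch

    def pipeline_torch_modules(pipe):
        # `_get_signature_keys` yields the names of the pipeline's __init__
        # arguments; `getattr(..., None)` tolerates components that were
        # never set on the instance.
        module_names, _ = pipe._get_signature_keys(pipe)
        candidates = [getattr(pipe, n, None) for n in module_names]
        return [m for m in candidates if isinstance(m, torch.nn.Module)]

Going through `__init__` arguments avoids calling `hasattr(self, name)`, which after this PR could trigger the deprecated config-attribute fallback.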
@@ -441,7 +441,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
...
@@ -413,9 +413,9 @@ class UnCLIPPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
@@ -466,9 +466,9 @@ class UnCLIPPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
...
@@ -339,9 +339,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size

         if decoder_latents is None:
             decoder_latents = self.prepare_latents(
@@ -393,9 +393,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size

         if super_res_latents is None:
             super_res_latents = self.prepare_latents(
...
@@ -18,7 +18,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import deprecate, logging
+from ...utils import logging

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -544,19 +544,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            (
-                "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
-                " `unet.config.in_channels` instead"
-            ),
-            standard_warn=False,
-        )
-        return self.config.in_channels
-
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
...
@@ -533,7 +533,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.image_unet.in_channels
+        num_channels_latents = self.image_unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
...
@@ -378,7 +378,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.image_unet.in_channels
+        num_channels_latents = self.image_unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
...
@@ -452,7 +452,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.image_unet.in_channels
+        num_channels_latents = self.image_unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
...
@@ -22,7 +22,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate, randn_tensor
+from ..utils import BaseOutput, randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -168,16 +168,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.variance_type = variance_type

-    @property
-    def num_train_timesteps(self):
-        deprecate(
-            "num_train_timesteps",
-            "1.0.0",
-            "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use `scheduler.config.num_train_timesteps instead`",
-            standard_warn=False,
-        )
-        return self.config.num_train_timesteps
-
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
...
@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Union

 from packaging import version

-def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
+def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
     from .. import __version__

     deprecated_kwargs = take_from
@@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
     if warning is not None:
         warning = warning + " " if standard_warn else ""
-        warnings.warn(warning + message, FutureWarning, stacklevel=2)
+        warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)

     if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
         call_frame = inspect.getouterframes(inspect.currentframe())[1]
...
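`stacklevel` controls which stack frame a warning is attributed to; `ModelMixin.__getattr__` passes `stacklevel=3` so the `FutureWarning` points at the user's attribute access rather than at `deprecate` or `__getattr__` itself. A standalone sketch of the mechanics (the function names are illustrative):

    import warnings

    def inner():
        # stacklevel=2 would blame inner()'s caller; stacklevel=3 blames
        # the caller's caller, one frame further up the stack.
        warnings.warn("deprecated", FutureWarning, stacklevel=3)

    def middle():
        inner()

    middle()  # the warning is reported at this line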
@@ -26,8 +26,8 @@ from requests.exceptions import HTTPError

 from diffusers.models import UNet2DConditionModel
 from diffusers.training_utils import EMAModel
-from diffusers.utils import torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils import logging, torch_device
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu

 class ModelUtilsTest(unittest.TestCase):
@@ -155,6 +155,49 @@ class ModelTesterMixin:
         max_diff = (image - new_image).abs().sum().item()

         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

+    def test_getattr_is_correct(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        # save some things to test
+        model.dummy_attribute = 5
+        model.register_to_config(test_attribute=5)
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "dummy_attribute")
+            assert getattr(model, "dummy_attribute") == 5
+            assert model.dummy_attribute == 5
+
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "save_pretrained")
+            fn = model.save_pretrained
+            fn_1 = getattr(model, "save_pretrained")
+
+            assert fn == fn_1
+
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+        # warning should be thrown
+        with self.assertWarns(FutureWarning):
+            assert model.test_attribute == 5
+
+        with self.assertWarns(FutureWarning):
+            assert getattr(model, "test_attribute") == 5
+
+        with self.assertRaises(AttributeError) as error:
+            model.does_not_exist
+
+        assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
+
     def test_from_save_pretrained_variant(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...
@@ -293,16 +293,16 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         prior_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
         )
-        shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size)
+        shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
         decoder_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
         )

         shape = (
             batch_size,
-            super_res_first.in_channels // 2,
-            super_res_first.sample_size,
-            super_res_first.sample_size,
+            super_res_first.config.in_channels // 2,
+            super_res_first.config.sample_size,
+            super_res_first.config.sample_size,
         )
         super_res_latents = pipe.prepare_latents(
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
...