Unverified Commit f1484b81 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Utils] Add deprecate function and move testing_utils under utils (#659)

* [Utils] Add deprecate function

* up

* up

* uP

* up

* up

* up

* up

* uP

* up

* fix

* up

* move to deprecation utils file

* fix

* fix

* fix more
parent 1070e1a3
......@@ -32,13 +32,13 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_addoption(parser):
from diffusers.testing_utils import pytest_addoption_shared
from diffusers.utils.testing_utils import pytest_addoption_shared
pytest_addoption_shared(parser)
def pytest_terminal_summary(terminalreporter):
from diffusers.testing_utils import pytest_terminal_summary_main
from diffusers.utils.testing_utils import pytest_terminal_summary_main
make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
......
......@@ -24,7 +24,7 @@ import unittest
from typing import List
from accelerate.utils import write_basic_config
from diffusers.testing_utils import slow
from diffusers.utils import slow
logging.basicConfig(level=logging.DEBUG)
......
import inspect
import warnings
from typing import Callable, List, Optional, Union
import torch
......@@ -10,7 +9,7 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -59,15 +58,15 @@ class StableDiffusionPipeline(DiffusionPipeline):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
......
import inspect
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
......@@ -12,7 +11,7 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -71,15 +70,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
......
import inspect
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
......@@ -13,7 +12,7 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -86,15 +85,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
......
......@@ -16,7 +16,6 @@
# and https://github.com/hojonathanho/diffusion
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -24,7 +23,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
......@@ -122,12 +121,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
......@@ -175,17 +174,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
offset = self.config.steps_offset
if "offset" in kwargs:
warnings.warn(
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
" Please pass `steps_offset` to `__init__` instead.",
DeprecationWarning,
)
offset = kwargs["offset"]
deprecated_offset = deprecate(
"offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
)
offset = deprecated_offset or self.config.steps_offset
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
......
......@@ -15,7 +15,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -23,7 +22,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
......@@ -115,12 +114,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -21,7 +20,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
......@@ -89,12 +88,12 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_max: float = 50,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
# setable values
self.num_inference_steps: int = None
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -22,7 +21,7 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
......@@ -77,12 +76,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
......
......@@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput
......@@ -102,12 +102,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
......@@ -155,16 +155,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
offset = self.config.steps_offset
if "offset" in kwargs:
warnings.warn(
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
" Please pass `steps_offset` to `__init__` instead."
)
offset = kwargs["offset"]
deprecated_offset = deprecate(
"offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
)
offset = deprecated_offset or self.config.steps_offset
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
......
......@@ -15,14 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput
......@@ -78,12 +77,12 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
correct_steps: int = 1,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
# setable values
self.timesteps = None
......@@ -139,11 +138,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
)
def set_seed(self, seed):
warnings.warn(
"The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
" generator instead.",
DeprecationWarning,
)
deprecate("set_seed", "0.5.0", "Please consider passing a generator instead.")
torch.manual_seed(seed)
def step_pred(
......
......@@ -17,11 +17,11 @@
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
import math
import warnings
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import SchedulerMixin
......@@ -42,12 +42,12 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
......
......@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
import torch
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
......@@ -44,10 +43,10 @@ class SchedulerMixin:
config_name = SCHEDULER_CONFIG_NAME
def set_format(self, tensor_format="pt"):
warnings.warn(
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
"are always in Pytorch",
DeprecationWarning,
deprecate(
"set_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this function as the schedulers are always"
" in Pytorch",
)
return self
......@@ -15,6 +15,7 @@
import os
from .deprecation_utils import deprecate
from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES,
......@@ -35,6 +36,7 @@ from .import_utils import (
)
from .logging import get_logger
from .outputs import BaseOutput
from .testing_utils import floats_tensor, load_image, parse_flag_from_env, slow, torch_device
logger = get_logger(__name__)
......
import inspect
import warnings
from typing import Any, Dict, Optional, Union
from packaging import version
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
from .. import __version__
deprecated_kwargs = take_from
values = ()
if not isinstance(args[0], tuple):
args = (args,)
for attribute, version_name, message in args:
if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
raise ValueError(
f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
f" version {__version__} is >= {version_name}"
)
warning = None
if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
values += (deprecated_kwargs.pop(attribute),)
warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
elif hasattr(deprecated_kwargs, attribute):
values += (getattr(deprecated_kwargs, attribute),)
warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
elif deprecated_kwargs is None:
warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
if warning is not None:
warning = warning + " " if standard_warn else ""
warnings.warn(warning + message, DeprecationWarning)
if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
call_frame = inspect.getouterframes(inspect.currentframe())[1]
filename = call_frame.filename
line_number = call_frame.lineno
function = call_frame.function
key, value = next(iter(deprecated_kwargs.items()))
raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
if len(values) == 0:
return
elif len(values) == 1:
return values[0]
return values
......@@ -15,13 +15,13 @@
Generic utilities
"""
import warnings
from collections import OrderedDict
from dataclasses import fields
from typing import Any, Tuple
import numpy as np
from .deprecation_utils import deprecate
from .import_utils import is_torch_available
......@@ -87,11 +87,7 @@ class BaseOutput(OrderedDict):
if isinstance(k, str):
inner_dict = {k: v for (k, v) in self.items()}
if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
warnings.warn(
"The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
" `'images'` instead.",
DeprecationWarning,
)
deprecate("samples", "0.6.0", "Please use `.images` or `'images'` instead.")
return inner_dict["images"]
return inner_dict[k]
else:
......
......@@ -31,13 +31,13 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_addoption(parser):
from diffusers.testing_utils import pytest_addoption_shared
from diffusers.utils.testing_utils import pytest_addoption_shared
pytest_addoption_shared(parser)
def pytest_terminal_summary(terminalreporter):
from diffusers.testing_utils import pytest_terminal_summary_main
from diffusers.utils.testing_utils import pytest_terminal_summary_main
make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
......
......@@ -22,7 +22,7 @@ import torch
from diffusers.models.attention import AttentionBlock, SpatialTransformer
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, Upsample2D
from diffusers.testing_utils import torch_device
from diffusers.utils import torch_device
torch.backends.cuda.matmul.allow_tf32 = False
......
......@@ -22,8 +22,8 @@ import numpy as np
import torch
from diffusers.modeling_utils import ModelMixin
from diffusers.testing_utils import torch_device
from diffusers.training_utils import EMAModel
from diffusers.utils import torch_device
class ModelTesterMixin:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment