Unverified commit f1484b81, authored by Patrick von Platen, committed by GitHub

[Utils] Add deprecate function and move testing_utils under utils (#659)

* [Utils] Add deprecate function

* up

* up

* uP

* up

* up

* up

* up

* uP

* up

* fix

* up

* move to deprecation utils file

* fix

* fix

* fix more
parent 1070e1a3
@@ -32,13 +32,13 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
 def pytest_addoption(parser):
-    from diffusers.testing_utils import pytest_addoption_shared
+    from diffusers.utils.testing_utils import pytest_addoption_shared
     pytest_addoption_shared(parser)
 def pytest_terminal_summary(terminalreporter):
-    from diffusers.testing_utils import pytest_terminal_summary_main
+    from diffusers.utils.testing_utils import pytest_terminal_summary_main
     make_reports = terminalreporter.config.getoption("--make-reports")
     if make_reports:
...
@@ -24,7 +24,7 @@ import unittest
 from typing import List
 from accelerate.utils import write_basic_config
-from diffusers.testing_utils import slow
+from diffusers.utils import slow
 logging.basicConfig(level=logging.DEBUG)
...
 import inspect
-import warnings
 from typing import Callable, List, Optional, Union
 import torch
@@ -10,7 +9,7 @@ from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import logging
+from ...utils import deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -59,15 +58,15 @@ class StableDiffusionPipeline(DiffusionPipeline):
         super().__init__()
         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-            warnings.warn(
+            deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                 "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                 " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
-                " file",
-                DeprecationWarning,
+                " file"
             )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
             new_config = dict(scheduler.config)
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
...
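`standard_warn=False` in the new call above suppresses the prefix that `deprecate` normally generates ("The `x` argument is deprecated and will be removed in version y."), so only the hand-written `deprecation_message` is emitted; see the helper's implementation further down in this diff. Roughly, the call reduces to the following (illustrative sketch, not part of the diff):

    import warnings

    deprecation_message = "The configuration file of this scheduler is outdated. ..."  # assembled as above
    warnings.warn(deprecation_message, DeprecationWarning)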
 import inspect
-import warnings
 from typing import Callable, List, Optional, Union
 import numpy as np
@@ -12,7 +11,7 @@ from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import logging
+from ...utils import deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -71,15 +70,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         super().__init__()
         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-            warnings.warn(
+            deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                 "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                 " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
-                " file",
-                DeprecationWarning,
+                " file"
             )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
             new_config = dict(scheduler.config)
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
...
 import inspect
-import warnings
 from typing import Callable, List, Optional, Union
 import numpy as np
@@ -13,7 +12,7 @@ from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import logging
+from ...utils import deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -86,15 +85,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-            warnings.warn(
+            deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                 "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                 " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
-                " file",
-                DeprecationWarning,
+                " file"
             )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
             new_config = dict(scheduler.config)
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
...
@@ -16,7 +16,6 @@
 # and https://github.com/hojonathanho/diffusion
 import math
-import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -24,7 +23,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin
@@ -122,12 +121,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         steps_offset: int = 0,
         **kwargs,
     ):
-        if "tensor_format" in kwargs:
-            warnings.warn(
-                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
-                "If you're running your code in PyTorch, you can safely remove this argument.",
-                DeprecationWarning,
-            )
+        deprecate(
+            "tensor_format",
+            "0.5.0",
+            "If you're running your code in PyTorch, you can safely remove this argument.",
+            take_from=kwargs,
+        )
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -175,17 +174,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
-        offset = self.config.steps_offset
-        if "offset" in kwargs:
-            warnings.warn(
-                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
-                " Please pass `steps_offset` to `__init__` instead.",
-                DeprecationWarning,
-            )
-            offset = kwargs["offset"]
+        deprecated_offset = deprecate(
+            "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
+        )
+        offset = deprecated_offset or self.config.steps_offset
         self.num_inference_steps = num_inference_steps
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
...
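With `take_from=kwargs`, `deprecate` pops the deprecated key out of `kwargs`, warns only when the caller actually passed it, and returns the popped value (or `None` when absent), which is what enables the one-line fallback above. A rough behavioral sketch (illustrative only; the real implementation appears later in this diff):

    # Approximately what the new set_timesteps lines do:
    deprecated_offset = kwargs.pop("offset", None)  # deprecate() also warns when the key is present
    offset = deprecated_offset or self.config.steps_offset

Note that the `or` fallback resolves an explicitly passed `offset=0` to `self.config.steps_offset` as well, the same as omitting the argument.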
@@ -15,7 +15,6 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 import math
-import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -23,7 +22,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin
@@ -115,12 +114,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         clip_sample: bool = True,
         **kwargs,
     ):
-        if "tensor_format" in kwargs:
-            warnings.warn(
-                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
-                "If you're running your code in PyTorch, you can safely remove this argument.",
-                DeprecationWarning,
-            )
+        deprecate(
+            "tensor_format",
+            "0.5.0",
+            "If you're running your code in PyTorch, you can safely remove this argument.",
+            take_from=kwargs,
+        )
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
...
@@ -13,7 +13,6 @@
 # limitations under the License.
-import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -21,7 +20,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin
@@ -89,12 +88,12 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         s_max: float = 50,
         **kwargs,
     ):
-        if "tensor_format" in kwargs:
-            warnings.warn(
-                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
-                "If you're running your code in PyTorch, you can safely remove this argument.",
-                DeprecationWarning,
-            )
+        deprecate(
+            "tensor_format",
+            "0.5.0",
+            "If you're running your code in PyTorch, you can safely remove this argument.",
+            take_from=kwargs,
+        )
         # setable values
         self.num_inference_steps: int = None
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -22,7 +21,7 @@ import torch
 from scipy import integrate
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin
@@ -77,12 +76,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[np.ndarray] = None,
         **kwargs,
     ):
-        if "tensor_format" in kwargs:
-            warnings.warn(
-                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
-                "If you're running your code in PyTorch, you can safely remove this argument.",
-                DeprecationWarning,
-            )
+        deprecate(
+            "tensor_format",
+            "0.5.0",
+            "If you're running your code in PyTorch, you can safely remove this argument.",
+            take_from=kwargs,
+        )
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
...
@@ -15,13 +15,13 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 import math
-import warnings
 from typing import Optional, Tuple, Union
 import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import deprecate
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -102,12 +102,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         steps_offset: int = 0,
         **kwargs,
     ):
-        if "tensor_format" in kwargs:
-            warnings.warn(
-                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
-                "If you're running your code in PyTorch, you can safely remove this argument.",
-                DeprecationWarning,
-            )
+        deprecate(
+            "tensor_format",
+            "0.5.0",
+            "If you're running your code in PyTorch, you can safely remove this argument.",
+            take_from=kwargs,
+        )
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -155,16 +155,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
-        offset = self.config.steps_offset
-        if "offset" in kwargs:
-            warnings.warn(
-                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
-                " Please pass `steps_offset` to `__init__` instead."
-            )
-            offset = kwargs["offset"]
+        deprecated_offset = deprecate(
+            "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
+        )
+        offset = deprecated_offset or self.config.steps_offset
         self.num_inference_steps = num_inference_steps
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
...
@@ -15,14 +15,13 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 import math
-import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -78,12 +77,12 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         correct_steps: int = 1,
         **kwargs,
     ):
-        if "tensor_format" in kwargs:
-            warnings.warn(
-                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
-                "If you're running your code in PyTorch, you can safely remove this argument.",
-                DeprecationWarning,
-            )
+        deprecate(
+            "tensor_format",
+            "0.5.0",
+            "If you're running your code in PyTorch, you can safely remove this argument.",
+            take_from=kwargs,
+        )
         # setable values
         self.timesteps = None
@@ -139,11 +138,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         )
     def set_seed(self, seed):
-        warnings.warn(
-            "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
-            " generator instead.",
-            DeprecationWarning,
-        )
+        deprecate("set_seed", "0.5.0", "Please consider passing a generator instead.")
         torch.manual_seed(seed)
     def step_pred(
...
@@ -17,11 +17,11 @@
 # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
 import math
-import warnings
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import deprecate
 from .scheduling_utils import SchedulerMixin
@@ -42,12 +42,12 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
     @register_to_config
     def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
-        if "tensor_format" in kwargs:
-            warnings.warn(
-                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
-                "If you're running your code in PyTorch, you can safely remove this argument.",
-                DeprecationWarning,
-            )
+        deprecate(
+            "tensor_format",
+            "0.5.0",
+            "If you're running your code in PyTorch, you can safely remove this argument.",
+            take_from=kwargs,
+        )
         self.sigmas = None
         self.discrete_sigmas = None
         self.timesteps = None
...
@@ -11,12 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
 from dataclasses import dataclass
 import torch
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -44,10 +43,10 @@ class SchedulerMixin:
     config_name = SCHEDULER_CONFIG_NAME
     def set_format(self, tensor_format="pt"):
-        warnings.warn(
-            "The method `set_format` is deprecated and will be removed in version `0.5.0`."
-            "If you're running your code in PyTorch, you can safely remove this function as the schedulers"
-            "are always in Pytorch",
-            DeprecationWarning,
-        )
+        deprecate(
+            "set_format",
+            "0.5.0",
+            "If you're running your code in PyTorch, you can safely remove this function as the schedulers are always"
+            " in Pytorch",
+        )
         return self
...
@@ -15,6 +15,7 @@
 import os
+from .deprecation_utils import deprecate
 from .import_utils import (
     ENV_VARS_TRUE_AND_AUTO_VALUES,
     ENV_VARS_TRUE_VALUES,
@@ -35,6 +36,7 @@ from .import_utils import (
 )
 from .logging import get_logger
 from .outputs import BaseOutput
+from .testing_utils import floats_tensor, load_image, parse_flag_from_env, slow, torch_device
 logger = get_logger(__name__)
...
+import inspect
+import warnings
+from typing import Any, Dict, Optional, Union
+
+from packaging import version
+
+
+def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
+    from .. import __version__
+
+    deprecated_kwargs = take_from
+    values = ()
+    if not isinstance(args[0], tuple):
+        args = (args,)
+
+    for attribute, version_name, message in args:
+        if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
+            raise ValueError(
+                f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
+                f" version {__version__} is >= {version_name}"
+            )
+
+        warning = None
+        if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
+            values += (deprecated_kwargs.pop(attribute),)
+            warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
+        elif hasattr(deprecated_kwargs, attribute):
+            values += (getattr(deprecated_kwargs, attribute),)
+            warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
+        elif deprecated_kwargs is None:
+            warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
+
+        if warning is not None:
+            warning = warning + " " if standard_warn else ""
+            warnings.warn(warning + message, DeprecationWarning)
+
+    if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
+        call_frame = inspect.getouterframes(inspect.currentframe())[1]
+        filename = call_frame.filename
+        line_number = call_frame.lineno
+        function = call_frame.function
+        key, value = next(iter(deprecated_kwargs.items()))
+        raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
+
+    if len(values) == 0:
+        return
+    elif len(values) == 1:
+        return values[0]
+    return values
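The call sites throughout this diff exercise three modes of the new helper. A minimal usage sketch follows; `ToyScheduler` and its methods are hypothetical, and note that `deprecate` itself raises a `ValueError` once the installed diffusers version reaches the stated removal version, so the "0.5.0" targets below only work on a pre-0.5.0 checkout:

    from diffusers.utils import deprecate

    class ToyScheduler:  # hypothetical class, for illustration only
        def __init__(self, steps_offset=0, **kwargs):
            # Mode 1: swallow a removed kwarg. Warns if the caller still passes
            # `tensor_format`; any other key left in kwargs raises a TypeError.
            deprecate("tensor_format", "0.5.0", "You can safely remove this argument.", take_from=kwargs)
            self.steps_offset = steps_offset

        def set_timesteps(self, num_inference_steps, **kwargs):
            # Mode 2: recover a renamed kwarg. Returns the popped value,
            # or None when the caller did not pass it.
            offset = deprecate("offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs)
            self.offset = offset or self.steps_offset

        def set_seed(self, seed):
            # Mode 3: deprecate a whole method. With no take_from,
            # the warning is emitted unconditionally.
            deprecate("set_seed", "0.5.0", "Please consider passing a generator instead.")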
@@ -15,13 +15,13 @@
 Generic utilities
 """
-import warnings
 from collections import OrderedDict
 from dataclasses import fields
 from typing import Any, Tuple
 import numpy as np
+from .deprecation_utils import deprecate
 from .import_utils import is_torch_available
@@ -87,11 +87,7 @@ class BaseOutput(OrderedDict):
         if isinstance(k, str):
             inner_dict = {k: v for (k, v) in self.items()}
             if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
-                warnings.warn(
-                    "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
-                    " `'images'` instead.",
-                    DeprecationWarning,
-                )
+                deprecate("samples", "0.6.0", "Please use `.images` or `'images'` instead.")
                 return inner_dict["images"]
             return inner_dict[k]
         else:
...
@@ -31,13 +31,13 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
 def pytest_addoption(parser):
-    from diffusers.testing_utils import pytest_addoption_shared
+    from diffusers.utils.testing_utils import pytest_addoption_shared
     pytest_addoption_shared(parser)
 def pytest_terminal_summary(terminalreporter):
-    from diffusers.testing_utils import pytest_terminal_summary_main
+    from diffusers.utils.testing_utils import pytest_terminal_summary_main
     make_reports = terminalreporter.config.getoption("--make-reports")
     if make_reports:
...
@@ -22,7 +22,7 @@ import torch
 from diffusers.models.attention import AttentionBlock, SpatialTransformer
 from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.resnet import Downsample2D, Upsample2D
-from diffusers.testing_utils import torch_device
+from diffusers.utils import torch_device
 torch.backends.cuda.matmul.allow_tf32 = False
...
@@ -22,8 +22,8 @@ import numpy as np
 import torch
 from diffusers.modeling_utils import ModelMixin
-from diffusers.testing_utils import torch_device
 from diffusers.training_utils import EMAModel
+from diffusers.utils import torch_device
 class ModelTesterMixin:
...