Unverified Commit f1484b81 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Utils] Add deprecate function and move testing_utils under utils (#659)

* [Utils] Add deprecate function

* up

* up

* uP

* up

* up

* up

* up

* uP

* up

* fix

* up

* move to deprecation utils file

* fix

* fix

* fix more
parent 1070e1a3
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
import torch import torch
from diffusers import UNet2DConditionModel, UNet2DModel from diffusers import UNet2DConditionModel, UNet2DModel
from diffusers.testing_utils import floats_tensor, slow, torch_device from diffusers.utils import floats_tensor, slow, torch_device
from .test_modeling_common import ModelTesterMixin from .test_modeling_common import ModelTesterMixin
......
...@@ -19,7 +19,7 @@ import torch ...@@ -19,7 +19,7 @@ import torch
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
from diffusers.modeling_utils import ModelMixin from diffusers.modeling_utils import ModelMixin
from diffusers.testing_utils import floats_tensor, torch_device from diffusers.utils import floats_tensor, torch_device
from .test_modeling_common import ModelTesterMixin from .test_modeling_common import ModelTesterMixin
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import torch import torch
from diffusers import VQModel from diffusers import VQModel
from diffusers.testing_utils import floats_tensor, torch_device from diffusers.utils import floats_tensor, torch_device
from .test_modeling_common import ModelTesterMixin from .test_modeling_common import ModelTesterMixin
......
...@@ -48,8 +48,7 @@ from diffusers import ( ...@@ -48,8 +48,7 @@ from diffusers import (
) )
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME
from PIL import Image from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
......
...@@ -18,8 +18,8 @@ import unittest ...@@ -18,8 +18,8 @@ import unittest
import torch import torch
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from diffusers.testing_utils import slow
from diffusers.training_utils import set_seed from diffusers.training_utils import set_seed
from diffusers.utils.testing_utils import slow
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import __version__
from diffusers.utils import deprecate
class DeprecateTester(unittest.TestCase):
higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:])
lower_version = "0.0.1"
def test_deprecate_function_arg(self):
kwargs = {"deprecated_arg": 4}
with self.assertWarns(DeprecationWarning) as warning:
output = deprecate("deprecated_arg", self.higher_version, "message", take_from=kwargs)
assert output == 4
assert (
str(warning.warning)
== f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}."
" message"
)
def test_deprecate_function_arg_tuple(self):
kwargs = {"deprecated_arg": 4}
with self.assertWarns(DeprecationWarning) as warning:
output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs)
assert output == 4
assert (
str(warning.warning)
== f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}."
" message"
)
def test_deprecate_function_args(self):
kwargs = {"deprecated_arg_1": 4, "deprecated_arg_2": 8}
with self.assertWarns(DeprecationWarning) as warning:
output_1, output_2 = deprecate(
("deprecated_arg_1", self.higher_version, "Hey"),
("deprecated_arg_2", self.higher_version, "Hey"),
take_from=kwargs,
)
assert output_1 == 4
assert output_2 == 8
assert (
str(warning.warnings[0].message)
== "The `deprecated_arg_1` argument is deprecated and will be removed in version"
f" {self.higher_version}. Hey"
)
assert (
str(warning.warnings[1].message)
== "The `deprecated_arg_2` argument is deprecated and will be removed in version"
f" {self.higher_version}. Hey"
)
def test_deprecate_function_incorrect_arg(self):
kwargs = {"deprecated_arg": 4}
with self.assertRaises(TypeError) as error:
deprecate(("wrong_arg", self.higher_version, "message"), take_from=kwargs)
assert "test_deprecate_function_incorrect_arg in" in str(error.exception)
assert "line" in str(error.exception)
assert "got an unexpected keyword argument `deprecated_arg`" in str(error.exception)
def test_deprecate_arg_no_kwarg(self):
with self.assertWarns(DeprecationWarning) as warning:
deprecate(("deprecated_arg", self.higher_version, "message"))
assert (
str(warning.warning)
== f"`deprecated_arg` is deprecated and will be removed in version {self.higher_version}. message"
)
def test_deprecate_args_no_kwarg(self):
with self.assertWarns(DeprecationWarning) as warning:
deprecate(
("deprecated_arg_1", self.higher_version, "Hey"),
("deprecated_arg_2", self.higher_version, "Hey"),
)
assert (
str(warning.warnings[0].message)
== f"`deprecated_arg_1` is deprecated and will be removed in version {self.higher_version}. Hey"
)
assert (
str(warning.warnings[1].message)
== f"`deprecated_arg_2` is deprecated and will be removed in version {self.higher_version}. Hey"
)
def test_deprecate_class_obj(self):
class Args:
arg = 5
with self.assertWarns(DeprecationWarning) as warning:
arg = deprecate(("arg", self.higher_version, "message"), take_from=Args())
assert arg == 5
assert (
str(warning.warning)
== f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
)
def test_deprecate_class_objs(self):
class Args:
arg = 5
foo = 7
with self.assertWarns(DeprecationWarning) as warning:
arg_1, arg_2 = deprecate(
("arg", self.higher_version, "message"),
("foo", self.higher_version, "message"),
("does not exist", self.higher_version, "message"),
take_from=Args(),
)
assert arg_1 == 5
assert arg_2 == 7
assert (
str(warning.warning)
== f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
)
assert (
str(warning.warnings[0].message)
== f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
)
assert (
str(warning.warnings[1].message)
== f"The `foo` attribute is deprecated and will be removed in version {self.higher_version}. message"
)
def test_deprecate_incorrect_version(self):
kwargs = {"deprecated_arg": 4}
with self.assertRaises(ValueError) as error:
deprecate(("wrong_arg", self.lower_version, "message"), take_from=kwargs)
assert (
str(error.exception)
== "The deprecation tuple ('wrong_arg', '0.0.1', 'message') should be removed since diffusers' version"
f" {__version__} is >= {self.lower_version}"
)
def test_deprecate_incorrect_no_standard_warn(self):
with self.assertWarns(DeprecationWarning) as warning:
deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)
assert str(warning.warning) == "This message is better!!!"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment