Unverified Commit 703307ef authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Fix config deprecation (#3129)



* Better deprecation message

* Better deprecation message

* Better doc string

* Fixes

* fix more

* fix more

* Improve __getattr__

* correct more

* fix more

* fix

* Improve more

* more improvements

* fix more

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* make style

* Fix all rest & add tests & remove old deprecation fns

---------
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent ed8fd383
......@@ -379,16 +379,21 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
dtype = pipe.decoder.dtype
batch_size = 1
shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size)
shape = (
batch_size,
pipe.decoder.config.in_channels,
pipe.decoder.config.sample_size,
pipe.decoder.config.sample_size,
)
decoder_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (
batch_size,
pipe.super_res_first.in_channels // 2,
pipe.super_res_first.sample_size,
pipe.super_res_first.sample_size,
pipe.super_res_first.config.in_channels // 2,
pipe.super_res_first.config.sample_size,
pipe.super_res_first.config.sample_size,
)
super_res_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
......
......@@ -596,3 +596,47 @@ class SchedulerCommonTest(unittest.TestCase):
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
def test_getattr_is_correct(self):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
# save some things to test
scheduler.dummy_attribute = 5
scheduler.register_to_config(test_attribute=5)
logger = logging.get_logger("diffusers.configuration_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
assert hasattr(scheduler, "dummy_attribute")
assert getattr(scheduler, "dummy_attribute") == 5
assert scheduler.dummy_attribute == 5
# no warning should be thrown
assert cap_logger.out == ""
logger = logging.get_logger("diffusers.schedulers.schedulering_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
assert hasattr(scheduler, "save_pretrained")
fn = scheduler.save_pretrained
fn_1 = getattr(scheduler, "save_pretrained")
assert fn == fn_1
# no warning should be thrown
assert cap_logger.out == ""
# warning should be thrown
with self.assertWarns(FutureWarning):
assert scheduler.test_attribute == 5
with self.assertWarns(FutureWarning):
assert getattr(scheduler, "test_attribute") == 5
with self.assertRaises(AttributeError) as error:
scheduler.does_not_exist
assert str(error.exception) == f"'{type(scheduler).__name__}' object has no attribute 'does_not_exist'"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment