Unverified commit a0520193, authored by Patrick von Platen, committed by GitHub

Add Scheduler.from_pretrained and better scheduler changing (#1286)



* add conversion script for vae

* uP

* uP

* more changes

* push

* up

* finish again

* up

* up

* up

* up

* finish

* up

* uP

* up

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* up

* up
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent db1cb0b1
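At a glance, the commit splits what `from_config` used to do into two explicit APIs: `from_pretrained` loads a scheduler from the Hub or a local folder (matching models and pipelines), while `from_config` now only instantiates from an in-memory config. A minimal before/after sketch using a repo id that appears in the tests below; the snippet is illustrative and not part of the diff:

```python
from diffusers import DDIMScheduler, PNDMScheduler

# Before this commit, from_config doubled as the Hub loader (now deprecated):
#   scheduler = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# After: from_pretrained loads from the Hub or a local directory ...
scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# ... and from_config instantiates from an already-loaded config, which is the
# new idiom for swapping one scheduler for a compatible one.
scheduler = PNDMScheduler.from_config(scheduler.config)
```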
@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
         model_id = "google/ddpm-ema-celebahq-256"
         unet = UNet2DModel.from_pretrained(model_id)
-        scheduler = RePaintScheduler.from_config(model_id)
+        scheduler = RePaintScheduler.from_pretrained(model_id)
         repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device)
@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
         model_id = "google/ncsnpp-church-256"
         model = UNet2DModel.from_pretrained(model_id)
-        scheduler = ScoreSdeVeScheduler.from_config(model_id)
+        scheduler = ScoreSdeVeScheduler.from_pretrained(model_id)
         sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
         sde_ve.to(torch_device)
@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
         init_image = init_image.resize((512, 512))
         model_id = "CompVis/stable-diffusion-v1-4"
-        scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
+        scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
         pipe = CycleDiffusionPipeline.from_pretrained(
             model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
         )
@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
         init_image = init_image.resize((512, 512))
         model_id = "CompVis/stable-diffusion-v1-4"
-        scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
+        scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
         pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)
         pipe.to(torch_device)
@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
     def test_inference_ddim(self):
-        ddim_scheduler = DDIMScheduler.from_config(
+        ddim_scheduler = DDIMScheduler.from_pretrained(
             "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
         )
         sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
     def test_inference_k_lms(self):
-        lms_scheduler = LMSDiscreteScheduler.from_config(
+        lms_scheduler = LMSDiscreteScheduler.from_pretrained(
             "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
         )
         sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
             "/img2img/sketch-mountains-input.jpg"
         )
         init_image = init_image.resize((768, 512))
-        lms_scheduler = LMSDiscreteScheduler.from_config(
+        lms_scheduler = LMSDiscreteScheduler.from_pretrained(
             "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
         )
         pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
         )
-        lms_scheduler = LMSDiscreteScheduler.from_config(
+        lms_scheduler = LMSDiscreteScheduler.from_pretrained(
             "runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx"
         )
         pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
     def test_stable_diffusion_fast_ddim(self):
-        scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
+        scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
         sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
         sd_pipe = sd_pipe.to(torch_device)
@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         model_id = "CompVis/stable-diffusion-v1-1"
         pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+        scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
         pipe.scheduler = scheduler
         prompt = "a photograph of an astronaut riding a horse"
@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
         )
         model_id = "CompVis/stable-diffusion-v1-4"
-        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+        lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
             scheduler=lms,
@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
         )
         model_id = "CompVis/stable-diffusion-v1-4"
-        ddim = DDIMScheduler.from_config(model_id, subfolder="scheduler")
+        ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
             scheduler=ddim,
@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
         init_image = init_image.resize((768, 512))
         model_id = "CompVis/stable-diffusion-v1-4"
-        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+        lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
         )
@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
         )
         model_id = "runwayml/stable-diffusion-inpainting"
-        pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
+        pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
         pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
         )
         model_id = "runwayml/stable-diffusion-inpainting"
-        pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
+        pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
             safety_checker=None,
@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
         )
         model_id = "CompVis/stable-diffusion-v1-4"
-        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+        lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
             scheduler=lms,
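Every integration-test hunk above follows the same two-step recipe, so it is worth spelling it out once: load the scheduler on its own with `from_pretrained`, then hand it to the pipeline. A condensed sketch (repo id and prompt borrowed from the tests; the snippet itself is illustrative, not part of the diff):

```python
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"
# Load only the scheduler sub-config from the pipeline repo ...
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
# ... then pass it in so the pipeline uses it instead of its default scheduler.
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=lms)
image = pipe("a photograph of an astronaut riding a horse").images[0]
```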
@@ -13,12 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
-import os
 import tempfile
 import unittest
-import diffusers
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
 class ConfigTester(unittest.TestCase):
     def test_load_not_from_mixin(self):
         with self.assertRaises(ValueError):
-            ConfigMixin.from_config("dummy_path")
+            ConfigMixin.load_config("dummy_path")
     def test_register_to_config(self):
         obj = SampleObject()
@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             obj.save_config(tmpdirname)
-            new_obj = SampleObject.from_config(tmpdirname)
+            new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname))
             new_config = new_obj.config
         # unfreeze configs
@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
         assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
         assert config == new_config
-    def test_save_load_from_different_config(self):
-        obj = SampleObject()
-        # mock add obj class to `diffusers`
-        setattr(diffusers, "SampleObject", SampleObject)
-        logger = logging.get_logger("diffusers.configuration_utils")
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            obj.save_config(tmpdirname)
-            with CaptureLogger(logger) as cap_logger_1:
-                new_obj_1 = SampleObject2.from_config(tmpdirname)
-            # now save a config parameter that is not expected
-            with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
-                data = json.load(f)
-                data["unexpected"] = True
-            with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
-                json.dump(data, f)
-            with CaptureLogger(logger) as cap_logger_2:
-                new_obj_2 = SampleObject.from_config(tmpdirname)
-            with CaptureLogger(logger) as cap_logger_3:
-                new_obj_3 = SampleObject2.from_config(tmpdirname)
-        assert new_obj_1.__class__ == SampleObject2
-        assert new_obj_2.__class__ == SampleObject
-        assert new_obj_3.__class__ == SampleObject2
-        assert cap_logger_1.out == ""
-        assert (
-            cap_logger_2.out
-            == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
-            " be ignored. Please verify your config.json configuration file.\n"
-        )
-        assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out
-    def test_save_load_compatible_schedulers(self):
-        SampleObject2._compatible_classes = ["SampleObject"]
-        SampleObject._compatible_classes = ["SampleObject2"]
-        obj = SampleObject()
-        # mock add obj class to `diffusers`
-        setattr(diffusers, "SampleObject", SampleObject)
-        setattr(diffusers, "SampleObject2", SampleObject2)
-        logger = logging.get_logger("diffusers.configuration_utils")
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            obj.save_config(tmpdirname)
-            # now save a config parameter that is expected by another class, but not origin class
-            with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
-                data = json.load(f)
-                data["f"] = [0, 0]
-                data["unexpected"] = True
-            with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
-                json.dump(data, f)
-            with CaptureLogger(logger) as cap_logger:
-                new_obj = SampleObject.from_config(tmpdirname)
-        assert new_obj.__class__ == SampleObject
-        assert (
-            cap_logger.out
-            == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
-            " be ignored. Please verify your config.json configuration file.\n"
-        )
-    def test_save_load_from_different_config_comp_schedulers(self):
-        SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"]
-        SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"]
-        SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"]
-        obj = SampleObject()
-        # mock add obj class to `diffusers`
-        setattr(diffusers, "SampleObject", SampleObject)
-        setattr(diffusers, "SampleObject2", SampleObject2)
-        setattr(diffusers, "SampleObject3", SampleObject3)
-        logger = logging.get_logger("diffusers.configuration_utils")
-        logger.setLevel(diffusers.logging.INFO)
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            obj.save_config(tmpdirname)
-            with CaptureLogger(logger) as cap_logger_1:
-                new_obj_1 = SampleObject.from_config(tmpdirname)
-            with CaptureLogger(logger) as cap_logger_2:
-                new_obj_2 = SampleObject2.from_config(tmpdirname)
-            with CaptureLogger(logger) as cap_logger_3:
-                new_obj_3 = SampleObject3.from_config(tmpdirname)
-        assert new_obj_1.__class__ == SampleObject
-        assert new_obj_2.__class__ == SampleObject2
-        assert new_obj_3.__class__ == SampleObject3
-        assert cap_logger_1.out == ""
-        assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
-        assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
     def test_load_ddim_from_pndm(self):
         logger = logging.get_logger("diffusers.configuration_utils")
         with CaptureLogger(logger) as cap_logger:
-            ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+            ddim = DDIMScheduler.from_pretrained(
+                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
+            )
         assert ddim.__class__ == DDIMScheduler
         # no warning should be thrown
@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
         logger = logging.get_logger("diffusers.configuration_utils")
         with CaptureLogger(logger) as cap_logger:
-            euler = EulerDiscreteScheduler.from_config(
+            euler = EulerDiscreteScheduler.from_pretrained(
                 "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
             )
@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
         logger = logging.get_logger("diffusers.configuration_utils")
         with CaptureLogger(logger) as cap_logger:
-            euler = EulerAncestralDiscreteScheduler.from_config(
+            euler = EulerAncestralDiscreteScheduler.from_pretrained(
                 "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
             )
@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
         logger = logging.get_logger("diffusers.configuration_utils")
         with CaptureLogger(logger) as cap_logger:
-            pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+            pndm = PNDMScheduler.from_pretrained(
+                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
+            )
         assert pndm.__class__ == PNDMScheduler
         # no warning should be thrown
@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
         logger = logging.get_logger("diffusers.configuration_utils")
         with CaptureLogger(logger) as cap_logger:
-            ddpm = DDPMScheduler.from_config(
+            ddpm = DDPMScheduler.from_pretrained(
                 "hf-internal-testing/tiny-stable-diffusion-torch",
                 subfolder="scheduler",
                 predict_epsilon=False,
@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
             )
         with CaptureLogger(logger) as cap_logger_2:
-            ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)
+            ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
         assert ddpm.__class__ == DDPMScheduler
         assert ddpm.config.predict_epsilon is False
@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
         logger = logging.get_logger("diffusers.configuration_utils")
         with CaptureLogger(logger) as cap_logger:
-            dpm = DPMSolverMultistepScheduler.from_config(
+            dpm = DPMSolverMultistepScheduler.from_pretrained(
                 "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
             )
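The `test_config.py` hunks above pin down the new `ConfigMixin` contract: `load_config` fetches and parses the JSON file (and is what fails on a bad path), `from_config` only accepts an already-loaded config, and `from_pretrained` does both in one call while still accepting keyword overrides such as `predict_epsilon` or `beta_start`. A short sketch of the entry points, using the same repo id as the tests (illustrative, not part of the diff):

```python
from diffusers import DDPMScheduler

# Two-step form: load_config returns the config, from_config instantiates.
config = DDPMScheduler.load_config("google/ddpm-celebahq-256")
ddpm = DDPMScheduler.from_config(config)

# One-step form with a load-time override, as the tests above exercise
# (the unusual beta_start value is the one the test asserts on).
ddpm = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
```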
@@ -130,7 +130,7 @@ class ModelTesterMixin:
         expected_arg_names = ["sample", "timestep"]
         self.assertListEqual(arg_names[:2], expected_arg_names)
-    def test_model_from_config(self):
+    def test_model_from_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
@@ -140,8 +140,8 @@ class ModelTesterMixin:
         # test if the model can be loaded from the config
         # and has all the expected shape
         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_config(tmpdirname)
-            new_model = self.model_class.from_config(tmpdirname)
+            model.save_pretrained(tmpdirname)
+            new_model = self.model_class.from_pretrained(tmpdirname)
             new_model.to(torch_device)
             new_model.eval()
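The renamed `ModelTesterMixin` test reflects the same split on the model side: models now round-trip through `save_pretrained`/`from_pretrained` (config plus weights) instead of `save_config`/`from_config` (config only). A sketch of that round trip with a concrete model class; the repo id is borrowed from the pipeline tests and its direct use with `UNet2DModel` is an assumption here:

```python
import tempfile

from diffusers import UNet2DModel

model = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
with tempfile.TemporaryDirectory() as tmpdirname:
    # save_pretrained writes config.json plus the weight file ...
    model.save_pretrained(tmpdirname)
    # ... so the reloaded model restores both shapes and parameters.
    new_model = UNet2DModel.from_pretrained(tmpdirname)
```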
@@ -29,6 +29,10 @@ from diffusers import (
     DDIMScheduler,
     DDPMPipeline,
     DDPMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipelineLegacy,
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
         assert image_img2img.shape == (1, 32, 32, 3)
         assert image_text2img.shape == (1, 128, 128, 3)
+    def test_set_scheduler(self):
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, DDIMScheduler)
+        sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, DDPMScheduler)
+        sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, PNDMScheduler)
+        sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, LMSDiscreteScheduler)
+        sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, EulerDiscreteScheduler)
+        sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
+        sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
+    def test_set_scheduler_consistency(self):
+        unet = self.dummy_cond_unet
+        pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+        ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=pndm,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        pndm_config = sd.scheduler.config
+        sd.scheduler = DDPMScheduler.from_config(pndm_config)
+        sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
+        pndm_config_2 = sd.scheduler.config
+        pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}
+        assert dict(pndm_config) == dict(pndm_config_2)
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=ddim,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        ddim_config = sd.scheduler.config
+        sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config)
+        sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
+        ddim_config_2 = sd.scheduler.config
+        ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}
+        assert dict(ddim_config) == dict(ddim_config_2)
 @slow
 class PipelineSlowTests(unittest.TestCase):
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
     def test_output_format(self):
         model_path = "google/ddpm-cifar10-32"
-        scheduler = DDIMScheduler.from_config(model_path)
+        scheduler = DDIMScheduler.from_pretrained(model_path)
         pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
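`test_set_scheduler` above is the "better scheduler changing" half of the commit: since any compatible scheduler can be rebuilt from another's config, swapping samplers on a live pipeline becomes a one-liner. The same pattern outside the test harness (model id as used in the tests; illustrative only):

```python
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# Rebuild a different scheduler from the current one's config (beta schedule,
# number of train timesteps, etc. carry over) and swap it in place.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```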
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import json
+import os
 import tempfile
 import unittest
 from typing import Dict, List, Tuple
@@ -21,6 +23,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
+import diffusers
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
@@ -32,13 +35,180 @@ from diffusers import (
     PNDMScheduler,
     ScoreSdeVeScheduler,
     VQDiffusionScheduler,
+    logging,
 )
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.utils import deprecate, torch_device
+from diffusers.utils.testing_utils import CaptureLogger
 torch.backends.cuda.matmul.allow_tf32 = False
+class SchedulerObject(SchedulerMixin, ConfigMixin):
+    config_name = "config.json"
+    @register_to_config
+    def __init__(
+        self,
+        a=2,
+        b=5,
+        c=(2, 5),
+        d="for diffusion",
+        e=[1, 3],
+    ):
+        pass
+class SchedulerObject2(SchedulerMixin, ConfigMixin):
+    config_name = "config.json"
+    @register_to_config
+    def __init__(
+        self,
+        a=2,
+        b=5,
+        c=(2, 5),
+        d="for diffusion",
+        f=[1, 3],
+    ):
+        pass
+class SchedulerObject3(SchedulerMixin, ConfigMixin):
+    config_name = "config.json"
+    @register_to_config
+    def __init__(
+        self,
+        a=2,
+        b=5,
+        c=(2, 5),
+        d="for diffusion",
+        e=[1, 3],
+        f=[1, 3],
+    ):
+        pass
+class SchedulerBaseTests(unittest.TestCase):
+    def test_save_load_from_different_config(self):
+        obj = SchedulerObject()
+        # mock add obj class to `diffusers`
+        setattr(diffusers, "SchedulerObject", SchedulerObject)
+        logger = logging.get_logger("diffusers.configuration_utils")
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+            with CaptureLogger(logger) as cap_logger_1:
+                config = SchedulerObject2.load_config(tmpdirname)
+                new_obj_1 = SchedulerObject2.from_config(config)
+            # now save a config parameter that is not expected
+            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
+                data = json.load(f)
+                data["unexpected"] = True
+            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
+                json.dump(data, f)
+            with CaptureLogger(logger) as cap_logger_2:
+                config = SchedulerObject.load_config(tmpdirname)
+                new_obj_2 = SchedulerObject.from_config(config)
+            with CaptureLogger(logger) as cap_logger_3:
+                config = SchedulerObject2.load_config(tmpdirname)
+                new_obj_3 = SchedulerObject2.from_config(config)
+        assert new_obj_1.__class__ == SchedulerObject2
+        assert new_obj_2.__class__ == SchedulerObject
+        assert new_obj_3.__class__ == SchedulerObject2
+        assert cap_logger_1.out == ""
+        assert (
+            cap_logger_2.out
+            == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
+            " will"
+            " be ignored. Please verify your config.json configuration file.\n"
+        )
+        assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out
+    def test_save_load_compatible_schedulers(self):
+        SchedulerObject2._compatibles = ["SchedulerObject"]
+        SchedulerObject._compatibles = ["SchedulerObject2"]
+        obj = SchedulerObject()
+        # mock add obj class to `diffusers`
+        setattr(diffusers, "SchedulerObject", SchedulerObject)
+        setattr(diffusers, "SchedulerObject2", SchedulerObject2)
+        logger = logging.get_logger("diffusers.configuration_utils")
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+            # now save a config parameter that is expected by another class, but not origin class
+            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
+                data = json.load(f)
+                data["f"] = [0, 0]
+                data["unexpected"] = True
+            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
+                json.dump(data, f)
+            with CaptureLogger(logger) as cap_logger:
+                config = SchedulerObject.load_config(tmpdirname)
+                new_obj = SchedulerObject.from_config(config)
+        assert new_obj.__class__ == SchedulerObject
+        assert (
+            cap_logger.out
+            == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
+            " will"
+            " be ignored. Please verify your config.json configuration file.\n"
+        )
+    def test_save_load_from_different_config_comp_schedulers(self):
+        SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"]
+        SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"]
+        SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"]
+        obj = SchedulerObject()
+        # mock add obj class to `diffusers`
+        setattr(diffusers, "SchedulerObject", SchedulerObject)
+        setattr(diffusers, "SchedulerObject2", SchedulerObject2)
+        setattr(diffusers, "SchedulerObject3", SchedulerObject3)
+        logger = logging.get_logger("diffusers.configuration_utils")
+        logger.setLevel(diffusers.logging.INFO)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+            with CaptureLogger(logger) as cap_logger_1:
+                config = SchedulerObject.load_config(tmpdirname)
+                new_obj_1 = SchedulerObject.from_config(config)
+            with CaptureLogger(logger) as cap_logger_2:
+                config = SchedulerObject2.load_config(tmpdirname)
+                new_obj_2 = SchedulerObject2.from_config(config)
+            with CaptureLogger(logger) as cap_logger_3:
+                config = SchedulerObject3.load_config(tmpdirname)
+                new_obj_3 = SchedulerObject3.from_config(config)
+        assert new_obj_1.__class__ == SchedulerObject
+        assert new_obj_2.__class__ == SchedulerObject2
+        assert new_obj_3.__class__ == SchedulerObject3
+        assert cap_logger_1.out == ""
+        assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
+        assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
 class SchedulerCommonTest(unittest.TestCase):
     scheduler_classes = ()
     forward_default_kwargs = ()
@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             scheduler.set_timesteps(num_inference_steps)
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             scheduler.set_timesteps(num_inference_steps)
@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             scheduler.set_timesteps(num_inference_steps)
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+    def test_compatibles(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            assert all(c is not None for c in scheduler.compatibles)
+            for comp_scheduler_cls in scheduler.compatibles:
+                comp_scheduler = comp_scheduler_cls.from_config(scheduler.config)
+                assert comp_scheduler is not None
+                new_scheduler = scheduler_class.from_config(comp_scheduler.config)
+                new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config}
+                scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config}
+                # make sure that configs are essentially identical
+                assert new_scheduler_config == dict(scheduler.config)
+                # make sure that only differences are for configs that are not in init
+                init_keys = inspect.signature(scheduler_class.__init__).parameters.keys()
+                assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set()
+    def test_from_pretrained(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_pretrained(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+            assert scheduler.config == new_scheduler.config
     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
             new_scheduler.set_timesteps(num_inference_steps)
             # copy over dummy past residuals
             new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
             # copy over dummy past residuals
             new_scheduler.set_timesteps(num_inference_steps)
@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
             new_scheduler.set_timesteps(num_inference_steps)
             # copy over dummy past residuals
             new_scheduler.ets = dummy_past_residuals[:]
@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
             # copy over dummy past residuals
             new_scheduler.set_timesteps(num_inference_steps)
@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
         output = scheduler.step_pred(
             residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
         output = scheduler.step_pred(
             residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
             new_scheduler.set_timesteps(num_inference_steps)
             # copy over dummy past residuals
             new_scheduler.ets = dummy_past_residuals[:]
@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
             # copy over dummy past residuals
             new_scheduler.set_timesteps(num_inference_steps)
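`test_compatibles` and `test_from_pretrained` above capture the two new scheduler guarantees: every class in `scheduler.compatibles` can be instantiated from the scheduler's own config without losing init parameters, and a local `save_pretrained`/`from_pretrained` round trip preserves the config exactly. A sketch of both, assuming network access for the initial load:

```python
import tempfile

from diffusers import DDIMScheduler

scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# Every compatible scheduler class can be rebuilt from this config.
for scheduler_cls in scheduler.compatibles:
    print(scheduler_cls.__name__, type(scheduler_cls.from_config(scheduler.config)))

# A local round trip keeps the config identical.
with tempfile.TemporaryDirectory() as tmpdirname:
    scheduler.save_pretrained(tmpdirname)
    assert DDIMScheduler.from_pretrained(tmpdirname).config == scheduler.config
```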
@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
             new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
             # copy over dummy past residuals
             new_state = new_state.replace(ets=dummy_past_residuals[:])
@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
             # copy over dummy past residuals
             new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
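The Flax tests migrate the same way, with one difference visible above: for stateless Flax schedulers, `from_pretrained` (like the old `from_config`) returns a `(scheduler, state)` tuple, and `set_timesteps` threads the state through explicitly. A sketch, assuming `flax` is installed; the concrete shape argument is an assumption mirroring the PNDM tests:

```python
from diffusers import FlaxPNDMScheduler

# Flax schedulers keep mutable state outside the module, so loading returns
# both the scheduler and its initial state.
scheduler, state = FlaxPNDMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 4, 64, 64))
```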