Unverified Commit a0520193 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Add Scheduler.from_pretrained and better scheduler changing (#1286)



* add conversion script for vae

* uP

* uP

* more changes

* push

* up

* finish again

* up

* up

* up

* up

* finish

* up

* uP

* up

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* up

* up
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent db1cb0b1
...@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase): ...@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-ema-celebahq-256" model_id = "google/ddpm-ema-celebahq-256"
unet = UNet2DModel.from_pretrained(model_id) unet = UNet2DModel.from_pretrained(model_id)
scheduler = RePaintScheduler.from_config(model_id) scheduler = RePaintScheduler.from_pretrained(model_id)
repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device) repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device)
......
...@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase): ...@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
model_id = "google/ncsnpp-church-256" model_id = "google/ncsnpp-church-256"
model = UNet2DModel.from_pretrained(model_id) model = UNet2DModel.from_pretrained(model_id)
scheduler = ScoreSdeVeScheduler.from_config(model_id) scheduler = ScoreSdeVeScheduler.from_pretrained(model_id)
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler) sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
sde_ve.to(torch_device) sde_ve.to(torch_device)
......
...@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((512, 512)) init_image = init_image.resize((512, 512))
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler") scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained( pipe = CycleDiffusionPipeline.from_pretrained(
model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16" model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
) )
...@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((512, 512)) init_image = init_image.resize((512, 512))
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler") scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None) pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)
pipe.to(torch_device) pipe.to(torch_device)
......
...@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_inference_ddim(self): def test_inference_ddim(self):
ddim_scheduler = DDIMScheduler.from_config( ddim_scheduler = DDIMScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
) )
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
...@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_inference_k_lms(self): def test_inference_k_lms(self):
lms_scheduler = LMSDiscreteScheduler.from_config( lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
) )
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
......
...@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
"/img2img/sketch-mountains-input.jpg" "/img2img/sketch-mountains-input.jpg"
) )
init_image = init_image.resize((768, 512)) init_image = init_image.resize((768, 512))
lms_scheduler = LMSDiscreteScheduler.from_config( lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
) )
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained( pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
......
...@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): ...@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png" "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
) )
lms_scheduler = LMSDiscreteScheduler.from_config( lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx" "runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx"
) )
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained( pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
......
...@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_fast_ddim(self): def test_stable_diffusion_fast_ddim(self):
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler") scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler) sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
sd_pipe = sd_pipe.to(torch_device) sd_pipe = sd_pipe.to(torch_device)
...@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
model_id = "CompVis/stable-diffusion-v1-1" model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device) pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe.scheduler = scheduler pipe.scheduler = scheduler
prompt = "a photograph of an astronaut riding a horse" prompt = "a photograph of an astronaut riding a horse"
......
...@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
) )
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained( pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, model_id,
scheduler=lms, scheduler=lms,
...@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
) )
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
ddim = DDIMScheduler.from_config(model_id, subfolder="scheduler") ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained( pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, model_id,
scheduler=ddim, scheduler=ddim,
...@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((768, 512)) init_image = init_image.resize((768, 512))
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained( pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16 model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
) )
......
...@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): ...@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
) )
model_id = "runwayml/stable-diffusion-inpainting" model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler") pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm) pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): ...@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
) )
model_id = "runwayml/stable-diffusion-inpainting" model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler") pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id, model_id,
safety_checker=None, safety_checker=None,
......
...@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): ...@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
) )
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id, model_id,
scheduler=lms, scheduler=lms,
......
...@@ -13,12 +13,9 @@ ...@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import os
import tempfile import tempfile
import unittest import unittest
import diffusers
from diffusers import ( from diffusers import (
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
...@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin): ...@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
class ConfigTester(unittest.TestCase): class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self): def test_load_not_from_mixin(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ConfigMixin.from_config("dummy_path") ConfigMixin.load_config("dummy_path")
def test_register_to_config(self): def test_register_to_config(self):
obj = SampleObject() obj = SampleObject()
...@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase): ...@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname) obj.save_config(tmpdirname)
new_obj = SampleObject.from_config(tmpdirname) new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname))
new_config = new_obj.config new_config = new_obj.config
# unfreeze configs # unfreeze configs
...@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase): ...@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
assert config == new_config assert config == new_config
def test_save_load_from_different_config(self):
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject2.from_config(tmpdirname)
# now save a config parameter that is not expected
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["unexpected"] = True
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject2.from_config(tmpdirname)
assert new_obj_1.__class__ == SampleObject2
assert new_obj_2.__class__ == SampleObject
assert new_obj_3.__class__ == SampleObject2
assert cap_logger_1.out == ""
assert (
cap_logger_2.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)
assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out
def test_save_load_compatible_schedulers(self):
SampleObject2._compatible_classes = ["SampleObject"]
SampleObject._compatible_classes = ["SampleObject2"]
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
# now save a config parameter that is expected by another class, but not origin class
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["f"] = [0, 0]
data["unexpected"] = True
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger:
new_obj = SampleObject.from_config(tmpdirname)
assert new_obj.__class__ == SampleObject
assert (
cap_logger.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)
def test_save_load_from_different_config_comp_schedulers(self):
SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"]
SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"]
SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"]
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
setattr(diffusers, "SampleObject3", SampleObject3)
logger = logging.get_logger("diffusers.configuration_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject2.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject3.from_config(tmpdirname)
assert new_obj_1.__class__ == SampleObject
assert new_obj_2.__class__ == SampleObject2
assert new_obj_3.__class__ == SampleObject3
assert cap_logger_1.out == ""
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
def test_load_ddim_from_pndm(self): def test_load_ddim_from_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils") logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") ddim = DDIMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
assert ddim.__class__ == DDIMScheduler assert ddim.__class__ == DDIMScheduler
# no warning should be thrown # no warning should be thrown
...@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase): ...@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils") logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
euler = EulerDiscreteScheduler.from_config( euler = EulerDiscreteScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
) )
...@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase): ...@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils") logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
euler = EulerAncestralDiscreteScheduler.from_config( euler = EulerAncestralDiscreteScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
) )
...@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase): ...@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils") logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") pndm = PNDMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
assert pndm.__class__ == PNDMScheduler assert pndm.__class__ == PNDMScheduler
# no warning should be thrown # no warning should be thrown
...@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase): ...@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils") logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
ddpm = DDPMScheduler.from_config( ddpm = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", "hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler", subfolder="scheduler",
predict_epsilon=False, predict_epsilon=False,
...@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase): ...@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
) )
with CaptureLogger(logger) as cap_logger_2: with CaptureLogger(logger) as cap_logger_2:
ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88) ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
assert ddpm.__class__ == DDPMScheduler assert ddpm.__class__ == DDPMScheduler
assert ddpm.config.predict_epsilon is False assert ddpm.config.predict_epsilon is False
...@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase): ...@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils") logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
dpm = DPMSolverMultistepScheduler.from_config( dpm = DPMSolverMultistepScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
) )
......
...@@ -130,7 +130,7 @@ class ModelTesterMixin: ...@@ -130,7 +130,7 @@ class ModelTesterMixin:
expected_arg_names = ["sample", "timestep"] expected_arg_names = ["sample", "timestep"]
self.assertListEqual(arg_names[:2], expected_arg_names) self.assertListEqual(arg_names[:2], expected_arg_names)
def test_model_from_config(self): def test_model_from_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict) model = self.model_class(**init_dict)
...@@ -140,8 +140,8 @@ class ModelTesterMixin: ...@@ -140,8 +140,8 @@ class ModelTesterMixin:
# test if the model can be loaded from the config # test if the model can be loaded from the config
# and has all the expected shape # and has all the expected shape
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_config(tmpdirname) model.save_pretrained(tmpdirname)
new_model = self.model_class.from_config(tmpdirname) new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device) new_model.to(torch_device)
new_model.eval() new_model.eval()
......
...@@ -29,6 +29,10 @@ from diffusers import ( ...@@ -29,6 +29,10 @@ from diffusers import (
DDIMScheduler, DDIMScheduler,
DDPMPipeline, DDPMPipeline,
DDPMScheduler, DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipelineLegacy, StableDiffusionInpaintPipelineLegacy,
...@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase): ...@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
assert image_img2img.shape == (1, 32, 32, 3) assert image_img2img.shape == (1, 32, 32, 3)
assert image_text2img.shape == (1, 128, 128, 3) assert image_text2img.shape == (1, 128, 128, 3)
def test_set_scheduler(self):
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
sd = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DDIMScheduler)
sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DDPMScheduler)
sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, PNDMScheduler)
sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, LMSDiscreteScheduler)
sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, EulerDiscreteScheduler)
sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
def test_set_scheduler_consistency(self):
unet = self.dummy_cond_unet
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
sd = StableDiffusionPipeline(
unet=unet,
scheduler=pndm,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
pndm_config = sd.scheduler.config
sd.scheduler = DDPMScheduler.from_config(pndm_config)
sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
pndm_config_2 = sd.scheduler.config
pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}
assert dict(pndm_config) == dict(pndm_config_2)
sd = StableDiffusionPipeline(
unet=unet,
scheduler=ddim,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
ddim_config = sd.scheduler.config
sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config)
sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
ddim_config_2 = sd.scheduler.config
ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}
assert dict(ddim_config) == dict(ddim_config_2)
@slow @slow
class PipelineSlowTests(unittest.TestCase): class PipelineSlowTests(unittest.TestCase):
...@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase): ...@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_output_format(self): def test_output_format(self):
model_path = "google/ddpm-cifar10-32" model_path = "google/ddpm-cifar10-32"
scheduler = DDIMScheduler.from_config(model_path) scheduler = DDIMScheduler.from_pretrained(model_path)
pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler) pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect import inspect
import json
import os
import tempfile import tempfile
import unittest import unittest
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
...@@ -21,6 +23,7 @@ import numpy as np ...@@ -21,6 +23,7 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import diffusers
from diffusers import ( from diffusers import (
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
...@@ -32,13 +35,180 @@ from diffusers import ( ...@@ -32,13 +35,180 @@ from diffusers import (
PNDMScheduler, PNDMScheduler,
ScoreSdeVeScheduler, ScoreSdeVeScheduler,
VQDiffusionScheduler, VQDiffusionScheduler,
logging,
) )
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import deprecate, torch_device from diffusers.utils import deprecate, torch_device
from diffusers.utils.testing_utils import CaptureLogger
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
class SchedulerObject(SchedulerMixin, ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
e=[1, 3],
):
pass
class SchedulerObject2(SchedulerMixin, ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
f=[1, 3],
):
pass
class SchedulerObject3(SchedulerMixin, ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
e=[1, 3],
f=[1, 3],
):
pass
class SchedulerBaseTests(unittest.TestCase):
def test_save_load_from_different_config(self):
obj = SchedulerObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SchedulerObject", SchedulerObject)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
config = SchedulerObject2.load_config(tmpdirname)
new_obj_1 = SchedulerObject2.from_config(config)
# now save a config parameter that is not expected
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
data = json.load(f)
data["unexpected"] = True
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger_2:
config = SchedulerObject.load_config(tmpdirname)
new_obj_2 = SchedulerObject.from_config(config)
with CaptureLogger(logger) as cap_logger_3:
config = SchedulerObject2.load_config(tmpdirname)
new_obj_3 = SchedulerObject2.from_config(config)
assert new_obj_1.__class__ == SchedulerObject2
assert new_obj_2.__class__ == SchedulerObject
assert new_obj_3.__class__ == SchedulerObject2
assert cap_logger_1.out == ""
assert (
cap_logger_2.out
== "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
" will"
" be ignored. Please verify your config.json configuration file.\n"
)
assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out
def test_save_load_compatible_schedulers(self):
SchedulerObject2._compatibles = ["SchedulerObject"]
SchedulerObject._compatibles = ["SchedulerObject2"]
obj = SchedulerObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SchedulerObject", SchedulerObject)
setattr(diffusers, "SchedulerObject2", SchedulerObject2)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
# now save a config parameter that is expected by another class, but not origin class
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
data = json.load(f)
data["f"] = [0, 0]
data["unexpected"] = True
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger:
config = SchedulerObject.load_config(tmpdirname)
new_obj = SchedulerObject.from_config(config)
assert new_obj.__class__ == SchedulerObject
assert (
cap_logger.out
== "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
" will"
" be ignored. Please verify your config.json configuration file.\n"
)
def test_save_load_from_different_config_comp_schedulers(self):
SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"]
SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"]
SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"]
obj = SchedulerObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SchedulerObject", SchedulerObject)
setattr(diffusers, "SchedulerObject2", SchedulerObject2)
setattr(diffusers, "SchedulerObject3", SchedulerObject3)
logger = logging.get_logger("diffusers.configuration_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
config = SchedulerObject.load_config(tmpdirname)
new_obj_1 = SchedulerObject.from_config(config)
with CaptureLogger(logger) as cap_logger_2:
config = SchedulerObject2.load_config(tmpdirname)
new_obj_2 = SchedulerObject2.from_config(config)
with CaptureLogger(logger) as cap_logger_3:
config = SchedulerObject3.load_config(tmpdirname)
new_obj_3 = SchedulerObject3.from_config(config)
assert new_obj_1.__class__ == SchedulerObject
assert new_obj_2.__class__ == SchedulerObject2
assert new_obj_3.__class__ == SchedulerObject3
assert cap_logger_1.out == ""
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
class SchedulerCommonTest(unittest.TestCase): class SchedulerCommonTest(unittest.TestCase):
scheduler_classes = () scheduler_classes = ()
forward_default_kwargs = () forward_default_kwargs = ()
...@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps) scheduler.set_timesteps(num_inference_steps)
...@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps) scheduler.set_timesteps(num_inference_steps)
...@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps) scheduler.set_timesteps(num_inference_steps)
...@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_compatibles(self):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
assert all(c is not None for c in scheduler.compatibles)
for comp_scheduler_cls in scheduler.compatibles:
comp_scheduler = comp_scheduler_cls.from_config(scheduler.config)
assert comp_scheduler is not None
new_scheduler = scheduler_class.from_config(comp_scheduler.config)
new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config}
scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config}
# make sure that configs are essentially identical
assert new_scheduler_config == dict(scheduler.config)
# make sure that only differences are for configs that are not in init
init_keys = inspect.signature(scheduler_class.__init__).parameters.keys()
assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set()
def test_from_pretrained(self):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_pretrained(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
assert scheduler.config == new_scheduler.config
def test_step_shape(self): def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
...@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps) new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
...@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps) new_scheduler.set_timesteps(num_inference_steps)
...@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps) new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
...@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps) new_scheduler.set_timesteps(num_inference_steps)
...@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
output = scheduler.step_pred( output = scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
...@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
output = scheduler.step_pred( output = scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
...@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest): ...@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps) new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
...@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest): ...@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps) new_scheduler.set_timesteps(num_inference_steps)
......
...@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase): ...@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_config(tmpdirname) new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps) state = scheduler.set_timesteps(state, num_inference_steps)
...@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase): ...@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_config(tmpdirname) new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps) state = scheduler.set_timesteps(state, num_inference_steps)
...@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase): ...@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_config(tmpdirname) new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps) state = scheduler.set_timesteps(state, num_inference_steps)
...@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): ...@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_config(tmpdirname) new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps) state = scheduler.set_timesteps(state, num_inference_steps)
...@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): ...@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_config(tmpdirname) new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps) state = scheduler.set_timesteps(state, num_inference_steps)
...@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): ...@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_config(tmpdirname) new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps) state = scheduler.set_timesteps(state, num_inference_steps)
...@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): ...@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_config(tmpdirname) new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
# copy over dummy past residuals # copy over dummy past residuals
new_state = new_state.replace(ets=dummy_past_residuals[:]) new_state = new_state.replace(ets=dummy_past_residuals[:])
...@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): ...@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_config(tmpdirname) new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals # copy over dummy past residuals
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment