Unverified Commit c18941b0 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Better scheduler docs] Improve usage examples of schedulers (#890)



* [Better scheduler docs] Improve usage examples of schedulers

* finish

* fix warnings and add test

* finish

* more replacements

* adapt fast tests hf token

* correct more

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Integrate compatibility with euler
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent a1ea8c01
...@@ -68,6 +68,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -68,6 +68,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
...@@ -67,6 +67,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -67,6 +67,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
...@@ -88,6 +88,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -88,6 +88,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
...@@ -644,13 +644,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -644,13 +644,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
sd_pipe = sd_pipe.to(torch_device) sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None) sd_pipe.set_progress_bar_config(disable=None)
scheduler = DDIMScheduler( scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
sd_pipe.scheduler = scheduler sd_pipe.scheduler = scheduler
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
......
...@@ -523,9 +523,8 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -523,9 +523,8 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((768, 512)) init_image = init_image.resize((768, 512))
expected_image = np.array(expected_image, dtype=np.float32) / 255.0 expected_image = np.array(expected_image, dtype=np.float32) / 255.0
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained( pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, model_id,
scheduler=lms, scheduler=lms,
......
...@@ -366,8 +366,8 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): ...@@ -366,8 +366,8 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
) )
expected_image = np.array(expected_image, dtype=np.float32) / 255.0 expected_image = np.array(expected_image, dtype=np.float32) / 255.0
pndm = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True)
model_id = "runwayml/stable-diffusion-inpainting" model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id, safety_checker=None, scheduler=pndm, device_map="auto" model_id, safety_checker=None, scheduler=pndm, device_map="auto"
) )
......
...@@ -407,9 +407,8 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): ...@@ -407,9 +407,8 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
) )
expected_image = np.array(expected_image, dtype=np.float32) / 255.0 expected_image = np.array(expected_image, dtype=np.float32) / 255.0
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id, model_id,
scheduler=lms, scheduler=lms,
......
...@@ -13,10 +13,15 @@ ...@@ -13,10 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import os
import tempfile import tempfile
import unittest import unittest
import diffusers
from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, PNDMScheduler, logging
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.testing_utils import CaptureLogger
class SampleObject(ConfigMixin): class SampleObject(ConfigMixin):
...@@ -34,6 +39,37 @@ class SampleObject(ConfigMixin): ...@@ -34,6 +39,37 @@ class SampleObject(ConfigMixin):
pass pass
class SampleObject2(ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
f=[1, 3],
):
pass
class SampleObject3(ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
e=[1, 3],
f=[1, 3],
):
pass
class ConfigTester(unittest.TestCase): class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self): def test_load_not_from_mixin(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -97,3 +133,151 @@ class ConfigTester(unittest.TestCase): ...@@ -97,3 +133,151 @@ class ConfigTester(unittest.TestCase):
assert config.pop("c") == (2, 5) # instantiated as tuple assert config.pop("c") == (2, 5) # instantiated as tuple
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
assert config == new_config assert config == new_config
def test_save_load_from_different_config(self):
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject2.from_config(tmpdirname)
# now save a config parameter that is not expected
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["unexpected"] = True
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject2.from_config(tmpdirname)
assert new_obj_1.__class__ == SampleObject2
assert new_obj_2.__class__ == SampleObject
assert new_obj_3.__class__ == SampleObject2
assert cap_logger_1.out == ""
assert (
cap_logger_2.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)
assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out
def test_save_load_compatible_schedulers(self):
SampleObject2._compatible_classes = ["SampleObject"]
SampleObject._compatible_classes = ["SampleObject2"]
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
# now save a config parameter that is expected by another class, but not origin class
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["f"] = [0, 0]
data["unexpected"] = True
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger:
new_obj = SampleObject.from_config(tmpdirname)
assert new_obj.__class__ == SampleObject
assert (
cap_logger.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)
def test_save_load_from_different_config_comp_schedulers(self):
SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"]
SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"]
SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"]
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
setattr(diffusers, "SampleObject3", SampleObject3)
logger = logging.get_logger("diffusers.configuration_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject2.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject3.from_config(tmpdirname)
assert new_obj_1.__class__ == SampleObject
assert new_obj_2.__class__ == SampleObject2
assert new_obj_3.__class__ == SampleObject3
assert cap_logger_1.out == ""
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
def test_load_ddim_from_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
ddim = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
assert ddim.__class__ == DDIMScheduler
# no warning should be thrown
assert cap_logger.out == ""
def test_load_ddim_from_euler(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
euler = EulerDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
assert euler.__class__ == EulerDiscreteScheduler
# no warning should be thrown
assert cap_logger.out == ""
def test_load_ddim_from_euler_ancestral(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
euler = EulerAncestralDiscreteScheduler.from_config(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
assert euler.__class__ == EulerAncestralDiscreteScheduler
# no warning should be thrown
assert cap_logger.out == ""
def test_load_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
pndm = PNDMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
assert pndm.__class__ == PNDMScheduler
# no warning should be thrown
assert cap_logger.out == ""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment