"tests/vscode:/vscode.git/clone" did not exist on "4b96addca0597af7acd2f38785aaf0cd40156166"
Unverified commit 4c54519e, authored by Patrick von Platen and committed by GitHub

Add 2nd order heun scheduler (#1336)

* Add heun

* Finish first version of heun

* remove bogus

* finish

* finish

* improve

* up

* up

* fix more

* change progress bar

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

* finish

* up

* up

* up
parent 25f11424
@@ -77,6 +77,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    """

+    order = 2
+
    @register_to_config
    def __init__(
        self,
...
@@ -68,6 +68,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """

    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    order = 1

    @register_to_config
    def __init__(
...
@@ -90,6 +90,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
    """

    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    order = 1

    @register_to_config
    def __init__(
...
@@ -102,6 +102,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
    """

+    order = 1
+
    @register_to_config
    def __init__(
        self,
...
@@ -66,6 +66,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        correct_steps (`int`): number of correction steps performed on a produced sample.
    """

+    order = 1
+
    @register_to_config
    def __init__(
        self,
...
@@ -38,6 +38,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    """

+    order = 1
+
    @register_to_config
    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
        self.sigmas = None
...
@@ -138,6 +138,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
            The ending cumulative gamma value.
    """

+    order = 1
+
    @register_to_config
    def __init__(
        self,
...
@@ -83,6 +83,7 @@ _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
    "PNDMScheduler",
    "LMSDiscreteScheduler",
    "EulerDiscreteScheduler",
+    "HeunDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
    "DPMSolverMultistepScheduler",
]
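With `HeunDiscreteScheduler` registered as compatible with the Stable Diffusion schedulers, it can be swapped into an existing pipeline from the current scheduler's config. A hedged usage sketch, not part of this diff (assumes the `CompVis/stable-diffusion-v1-4` checkpoint used in the tests below and a CUDA device):

```python
from diffusers import HeunDiscreteScheduler, StableDiffusionPipeline

# Swap the pipeline's default scheduler for the new 2nd-order Heun scheduler
# by rebuilding it from the shared scheduler config.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

image = pipe("astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut.png")
```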
@@ -362,6 +362,21 @@ class EulerDiscreteScheduler(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


+class HeunDiscreteScheduler(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
class IPNDMScheduler(metaclass=DummyObject):
    _backends = ["torch"]
...
@@ -928,7 +928,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
        prompt = "astronaut riding a horse"
        generator = torch.Generator(device=torch_device).manual_seed(0)
-        output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np")
+        output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
        image = output.images[0]

        assert image.shape == (512, 512, 3)
@@ -980,7 +980,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
            callback_steps=1,
        )
        assert test_callback_fn.has_been_called
-        assert number_of_steps == 51
+        assert number_of_steps == 50

    def test_stable_diffusion_low_cpu_mem_usage(self):
        pipeline_id = "CompVis/stable-diffusion-v1-4"
...
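The expected callback counts here and in the hunks below drop by one. Together with the new `order` attributes and the "change progress bar" commit, this suggests the denoising loops now count one step per `order` model evaluations, so the callback fires once per actual denoising step. A self-contained, hedged sketch of that kind of accounting (illustrative only, not the exact pipeline code):

```python
# Hedged sketch of order-aware step counting: with it, a 50-step run triggers
# the callback 50 times rather than 51, matching the updated assertions.
def count_callbacks(num_timesteps, num_inference_steps, order):
    num_warmup_steps = num_timesteps - num_inference_steps * order
    calls = 0
    for i in range(num_timesteps):
        # a step "counts" on the last iteration, or once per `order`
        # iterations after the warmup phase
        if i == num_timesteps - 1 or ((i + 1) > num_warmup_steps and (i + 1) % order == 0):
            calls += 1
    return calls

# 1st-order scheduler, 50 timesteps: the callback fires 50 times, not 51.
print(count_callbacks(num_timesteps=50, num_inference_steps=50, order=1))  # 50
# A 2nd-order scheduler expands 50 steps into ~99 timesteps but still counts 50.
print(count_callbacks(num_timesteps=99, num_inference_steps=50, order=2))  # 50
```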
@@ -351,7 +351,7 @@ class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([1.83, 1.293, -0.09705, 1.256, -2.293, 1.091, -0.0809, -0.65, -2.953])
-                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
+                assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
            elif step == 37:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 64)
...
@@ -386,7 +386,7 @@ class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
            callback_steps=1,
        )
        assert test_callback_fn.has_been_called
-        assert number_of_steps == 51
+        assert number_of_steps == 50

    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
...
@@ -635,7 +635,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
            callback_steps=1,
        )
        assert test_callback_fn.has_been_called
-        assert number_of_steps == 38
+        assert number_of_steps == 37

    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
...
@@ -484,4 +484,4 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
            callback_steps=1,
        )
        assert test_callback_fn.has_been_called
-        assert number_of_steps == 38
+        assert number_of_steps == 37
@@ -692,7 +692,7 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
            callback_steps=1,
        )
        assert test_callback_fn.has_been_called
-        assert number_of_steps == 21
+        assert number_of_steps == 20

    def test_stable_diffusion_low_cpu_mem_usage(self):
        pipeline_id = "stabilityai/stable-diffusion-2-base"
...
@@ -306,7 +306,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
        image_slice = image[0, 253:256, 253:256, -1]

        assert image.shape == (1, 768, 768, 3)
-        expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+        expected_slice = np.array([0.2049, 0.2115, 0.2323, 0.2416, 0.256, 0.2484, 0.2517, 0.2358, 0.236])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_attention_slicing_v_pred(self):
...
@@ -654,7 +654,10 @@ class PipelineSlowTests(unittest.TestCase):
            force_download=True,
        )

-        assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
+        assert (
+            cap_logger.out
+            == "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n"
+        )

    def test_from_pretrained_save_pretrained(self):
        # 1. Load models
...
@@ -30,6 +30,7 @@ from diffusers import (
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
    IPNDMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
...
@@ -1876,3 +1877,95 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest):
    def test_add_noise_device(self):
        pass
+
+
+class HeunDiscreteSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (HeunDiscreteScheduler,)
+    num_inference_steps = 10
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1100,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "trained_betas": None,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def test_timesteps(self):
+        for timesteps in [10, 50, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_betas(self):
+        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
+            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+    def test_schedules(self):
+        for schedule in ["linear", "scaled_linear"]:
+            self.check_over_configs(beta_schedule=schedule)
+
+    def test_full_loop_no_noise(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        if torch_device in ["cpu", "mps"]:
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+        else:
+            # CUDA
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+
+    def test_full_loop_device(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        if str(torch_device).startswith("cpu"):
+            # The following sum varies between 148 and 156 on mps. Why?
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
+        elif str(torch_device).startswith("mps"):
+            # Larger tolerance on mps
+            assert abs(result_mean.item() - 0.0002) < 1e-2
+        else:
+            # CUDA
+            assert abs(result_sum.item() - 0.1233) < 1e-2
+            assert abs(result_mean.item() - 0.0002) < 1e-3
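A note on the loops above: a second-order scheduler exposes its extra model evaluations by expanding the timestep list returned by `set_timesteps`, while `order` tells pipelines how to count steps. A small hedged check (the exact timestep count is an implementation detail and may differ between diffusers versions):

```python
# Hedged sketch: a 2nd-order scheduler interleaves timesteps so most steps run
# the model twice; the printed length is approximate and version-dependent.
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(10)

print(scheduler.order)           # expected: 2
print(len(scheduler.timesteps))  # roughly 2 * 10 rather than 10
```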