Unverified Commit 07c9a08e authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Add `timestep_spacing` and `steps_offset` to schedulers (#3947)



* Add timestep_spacing to DDPM, LMSDiscrete, PNDM.

* Remove spurious line.

* More easy schedulers.

* Add `linspace` to DDIM

* Noise sigma for `trailing`.

* Add timestep_spacing to DEISMultistepScheduler.

Not sure the range is the way it was intended.

* Fix: remove line used to debug.

* Support timestep_spacing in DPMSolverMultistep, DPMSolverSDE, UniPC

* Fix: convert to numpy.

* Use sched. defaults when instantiating from_config

For params not present in the original configuration.

This makes it possible to switch pipeline schedulers even if they use
different timestep_spacing (or any other param).

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Missing args in DPMSolverMultistep

* Test: default args not in config

* Style

* Fix scheduler name in test

* Remove duplicated entries

* Add test for solver_type

This test currently fails in main. When switching from DEIS to UniPC,
solver_type is "logrho" (the default value from DEIS), which gets
translated to "bh1" by UniPC. This is different to the default value for
UniPC: "bh2". This is where the translation happens: https://github.com/huggingface/diffusers/blob/36d22d0709dc19776e3016fb3392d0f5578b0ab2/src/diffusers/schedulers/scheduling_unipc_multistep.py#L171



* UniPC: use same default for solver_type

Fixes a bug when switching from UniPC from another scheduler (i.e.,
DEIS) that uses a different solver type. The solver is now the same as
if we had instantiated the scheduler directly.

* do not save use default values

* fix more

* fix all

* fix schedulers

* fix more

* finish for real

* finish for real

* flaky tests

* Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

* Default steps_offset to 0.

* Add missing docstrings

* Apply suggestions from code review

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 2837d490
......@@ -97,7 +97,7 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu()
sample = sample.to(torch_device)
for i, t in enumerate(scheduler.timesteps):
......
......@@ -23,7 +23,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
"beta_end": 0.02,
"beta_schedule": "linear",
"solver_order": 2,
"solver_type": "bh1",
"solver_type": "bh2",
}
config.update(**kwargs)
......@@ -144,7 +144,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
sample = self.full_loop(scheduler=scheduler)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2521) < 1e-3
assert abs(result_mean.item() - 0.2464) < 1e-3
scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
scheduler = DEISMultistepScheduler.from_config(scheduler.config)
......@@ -154,7 +154,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
sample = self.full_loop(scheduler=scheduler)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2521) < 1e-3
assert abs(result_mean.item() - 0.2464) < 1e-3
def test_timesteps(self):
for timesteps in [25, 50, 100, 999, 1000]:
......@@ -206,13 +206,13 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
sample = self.full_loop()
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2521) < 1e-3
assert abs(result_mean.item() - 0.2464) < 1e-3
def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.1096) < 1e-3
assert abs(result_mean.item() - 0.1014) < 1e-3
def test_fp16_support(self):
scheduler_class = self.scheduler_classes[0]
......
......@@ -24,10 +24,14 @@ import torch
import diffusers
from diffusers import (
DDIMScheduler,
DEISMultistepScheduler,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
IPNDMScheduler,
LMSDiscreteScheduler,
UniPCMultistepScheduler,
VQDiffusionScheduler,
logging,
)
......@@ -202,6 +206,44 @@ class SchedulerBaseTests(unittest.TestCase):
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
def test_default_arguments_not_in_config(self):
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16
)
assert pipe.scheduler.__class__ == DDIMScheduler
# Default for DDIMScheduler
assert pipe.scheduler.config.timestep_spacing == "leading"
# Switch to a different one, verify we use the default for that class
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
assert pipe.scheduler.config.timestep_spacing == "linspace"
# Override with kwargs
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
assert pipe.scheduler.config.timestep_spacing == "trailing"
# Verify overridden kwargs stick
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
assert pipe.scheduler.config.timestep_spacing == "trailing"
# And stick
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
assert pipe.scheduler.config.timestep_spacing == "trailing"
def test_default_solver_type_after_switch(self):
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16
)
assert pipe.scheduler.__class__ == DDIMScheduler
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
assert pipe.scheduler.config.solver_type == "logrho"
# Switch to UniPC, verify the solver is the default
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
assert pipe.scheduler.config.solver_type == "bh2"
class SchedulerCommonTest(unittest.TestCase):
scheduler_classes = ()
......@@ -414,7 +456,11 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_pretrained(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
assert scheduler.config == new_scheduler.config
# `_use_default_values` should not exist for just saved & loaded scheduler
scheduler_config = dict(scheduler.config)
del scheduler_config["_use_default_values"]
assert scheduler_config == new_scheduler.config
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment