Unverified Commit be99201a, authored by qsh-zh, committed by GitHub

feat: add log-rho DEIS multistep scheduler (#1432)



* feat: add log-rho DEIS multistep scheduler

* docs: fix typo

* docs: add docs for the implemented algorithm

* docs: remove duplicate ref

* finish deis

* add docs

* fix

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 9b638548
@@ -155,6 +155,8 @@
       title: "DDIM"
     - local: api/schedulers/ddpm
       title: "DDPM"
+    - local: api/schedulers/deis
+      title: "DEIS"
     - local: api/schedulers/singlestep_dpm_solver
       title: "Singlestep DPM-Solver"
     - local: api/schedulers/multistep_dpm_solver
......
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# DEIS
Fast Sampling of Diffusion Models with Exponential Integrator.
## Overview
DEIS (Diffusion Exponential Integrator Sampler) is a fast high-order solver for the diffusion ODE: rather than applying a generic numerical integrator, it exploits the semilinear structure of the ODE with an exponential integrator, so high-quality samples can be drawn in a small number of model evaluations. The original paper can be found [here](https://arxiv.org/abs/2204.13902), and the original implementation [here](https://github.com/qsh-zh/deis).
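As a rough sketch of the idea (notation follows the usual exponential-integrator formulation of the diffusion ODE, not necessarily the paper's): writing $\alpha_t$, $\sigma_t$ for the noise schedule and $\lambda_t = \log(\alpha_t / \sigma_t)$ for the log-SNR, the exact solution of the diffusion ODE between two timesteps $t \to s$ is

$$x_s = \frac{\alpha_s}{\alpha_t}\, x_t \;-\; \alpha_s \int_{\lambda_t}^{\lambda_s} e^{-\lambda}\, \hat\epsilon_\theta(x_\lambda, \lambda)\, \mathrm{d}\lambda ,$$

so only the noise prediction $\hat\epsilon_\theta$ inside the integral has to be approximated. DEIS fits it with a polynomial extrapolated from previous model outputs (the log-rho variant implemented here performs the fit in a transformed time variable), raising the order of accuracy without extra network evaluations per step.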
## DEISMultistepScheduler
[[autodoc]] DEISMultistepScheduler
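A minimal usage sketch (the checkpoint name below is illustrative and not part of this PR):

```python
import torch
from diffusers import DEISMultistepScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
# Reuse the pipeline's scheduler config so the noise schedule matches training.
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# DEIS is a fast multistep solver, so a small step count usually suffices.
image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
```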
@@ -67,6 +67,7 @@ else:
     from .schedulers import (
         DDIMScheduler,
         DDPMScheduler,
+        DEISMultistepScheduler,
         DPMSolverMultistepScheduler,
         DPMSolverSinglestepScheduler,
         EulerAncestralDiscreteScheduler,
......
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
 else:
     from .scheduling_ddim import DDIMScheduler
     from .scheduling_ddpm import DDPMScheduler
+    from .scheduling_deis_multistep import DEISMultistepScheduler
     from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
     from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
     from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
......
This diff is collapsed. (It is the new `scheduling_deis_multistep.py` module implementing `DEISMultistepScheduler`.)
@@ -174,9 +174,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
-            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
+            if algorithm_type == "deis":
+                algorithm_type = "dpmsolver++"
+            else:
+                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
-            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+            if solver_type == "logrho":
+                solver_type = "midpoint"
+            else:
+                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

         # setable values
         self.num_inference_steps = None
......
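In effect, DEIS-style config values no longer hard-fail in the DPM-Solver classes; they are remapped to the closest supported options. A small sketch of the resulting behavior (the class and arguments are real, the comments paraphrase the diff above):

```python
from diffusers import DPMSolverMultistepScheduler

# "deis" and "logrho" are not native DPM-Solver options; with this PR they are
# silently remapped to "dpmsolver++" and "midpoint" instead of raising.
scheduler = DPMSolverMultistepScheduler(algorithm_type="deis", solver_type="logrho")
```

The same mapping is added to `DPMSolverSinglestepScheduler` in the next hunk, keeping both DPM-Solver variants loadable from a `DEISMultistepScheduler` config.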
@@ -163,9 +163,15 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
-            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
+            if algorithm_type == "deis":
+                algorithm_type = "dpmsolver++"
+            else:
+                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
-            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+            if solver_type == "logrho":
+                solver_type = "midpoint"
+            else:
+                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

         # setable values
         self.num_inference_steps = None
......
@@ -41,4 +41,7 @@ _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
     "EulerAncestralDiscreteScheduler",
     "DPMSolverMultistepScheduler",
     "DPMSolverSinglestepScheduler",
+    "KDPM2DiscreteScheduler",
+    "KDPM2AncestralDiscreteScheduler",
+    "DEISMultistepScheduler",
 ]
@@ -362,6 +362,21 @@ class DDPMScheduler(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class DEISMultistepScheduler(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class DPMSolverMultistepScheduler(metaclass=DummyObject):
     _backends = ["torch"]
......
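For context (not part of this diff): the dummy-object pattern keeps `import` statements working in environments without PyTorch and defers the failure to first use:

```python
# Hypothetical session in an environment where torch is NOT installed:
from diffusers import DEISMultistepScheduler  # resolves to the dummy class above

scheduler = DEISMultistepScheduler()  # raises here, asking the user to install torch
```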
@@ -27,6 +27,7 @@ import diffusers
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
+    DEISMultistepScheduler,
     DPMSolverMultistepScheduler,
     DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,
@@ -2505,6 +2506,207 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 0.0266) < 1e-3
+
+
+class DEISMultistepSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (DEISMultistepScheduler,)
+    forward_default_kwargs = (("num_inference_steps", 25),)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1000,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "solver_order": 2,
+        }
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+            # copy over dummy past residuals
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+                # copy over dummy past residuals
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output, new_output = sample, sample
+            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+                output = scheduler.step(residual, t, output, **kwargs).prev_sample
+                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_save_pretrained(self):
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+            # copy over dummy past residuals (must be after setting timesteps)
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+                # copy over dummy past residuals (must be after setting timesteps)
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def full_loop(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        return sample
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            # copy over dummy past residuals (must be done after set_timesteps)
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            time_step_0 = scheduler.timesteps[5]
+            time_step_1 = scheduler.timesteps[6]
+
+            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
+            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+    def test_timesteps(self):
+        for timesteps in [25, 50, 100, 999, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_thresholding(self):
+        self.check_over_configs(thresholding=False)
+        for order in [1, 2, 3]:
+            for solver_type in ["logrho"]:
+                for threshold in [0.5, 1.0, 2.0]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            thresholding=True,
+                            prediction_type=prediction_type,
+                            sample_max_value=threshold,
+                            algorithm_type="deis",
+                            solver_order=order,
+                            solver_type=solver_type,
+                        )
+
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
+    def test_solver_order_and_type(self):
+        for algorithm_type in ["deis"]:
+            for solver_type in ["logrho"]:
+                for order in [1, 2, 3]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        sample = self.full_loop(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        assert not torch.isnan(sample).any(), "Samples have nan numbers"
+
+    def test_lower_order_final(self):
+        self.check_over_configs(lower_order_final=True)
+        self.check_over_configs(lower_order_final=False)
+
+    def test_inference_steps(self):
+        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
+            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_mean = torch.mean(torch.abs(sample))
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_mean = torch.mean(torch.abs(sample))
+        assert abs(result_mean.item() - 0.091) < 1e-3
+
+    def test_fp16_support(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter.half()
+        scheduler.set_timesteps(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        assert sample.dtype == torch.float16
+
+
 class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (KDPM2AncestralDiscreteScheduler,)
     num_inference_steps = 10
......
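The `full_loop` helper above also doubles as a template for driving the scheduler by hand. A standalone sketch (random tensors stand in for a real model; shapes are illustrative):

```python
import torch
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler(num_train_timesteps=1000, solver_order=2)
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 8, 8)  # stand-in for an initial noise latent
for t in scheduler.timesteps:
    residual = torch.randn_like(sample)  # stand-in for the model's noise prediction
    sample = scheduler.step(residual, t, sample).prev_sample
```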