Unverified Commit aaaec064 authored by Wenliang Zhao, committed by GitHub

add the UniPC scheduler (#2373)



* add UniPC scheduler

* add the return type to the functions

* code quality check

* add tests

* finish docs

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2777264e
......@@ -148,7 +148,7 @@
- local: api/pipelines/stable_diffusion/upscale
title: Super-Resolution
- local: api/pipelines/stable_diffusion/latent_upscale
title: Stable-Diffusion-Latent-Upscaler
- local: api/pipelines/stable_diffusion/pix2pix
title: InstructPix2Pix
- local: api/pipelines/stable_diffusion/pix2pix_zero
......@@ -204,6 +204,8 @@
title: Singlestep DPM-Solver
- local: api/schedulers/stochastic_karras_ve
title: Stochastic Karras VE
- local: api/schedulers/unipc
title: UniPCMultistepScheduler
- local: api/schedulers/score_sde_ve
title: VE-SDE
- local: api/schedulers/score_sde_vp
......
......@@ -43,11 +43,11 @@ To this end, the design of schedulers is such that:
The following table summarizes all officially supported schedulers and their corresponding papers:
| Scheduler | Paper |
|---|---|
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) |
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) |
| [deis](./deis) | [**Fast Sampling of Diffusion Models with Exponential Integrator**](https://arxiv.org/abs/2204.13902) |
| [singlestep_dpm_solver](./singlestep_dpm_solver) | [**Singlestep DPM-Solver**](https://arxiv.org/abs/2206.00927) |
| [multistep_dpm_solver](./multistep_dpm_solver) | [**Multistep DPM-Solver**](https://arxiv.org/abs/2206.00927) |
| [heun](./heun) | [**Heun scheduler inspired by the Karras et al. paper**](https://arxiv.org/abs/2206.00364) |
......@@ -62,6 +62,7 @@ The following table summarizes all officially supported schedulers, their corres
| [euler](./euler) | [**Euler scheduler**](https://arxiv.org/abs/2206.00364) |
| [euler_ancestral](./euler_ancestral) | [**Euler Ancestral scheduler**](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72) |
| [vq_diffusion](./vq_diffusion) | [**Vector Quantized Diffusion Model for Text-to-Image Synthesis**](https://arxiv.org/abs/2111.14822) |
| [unipc](./unipc) | [**UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models**](https://arxiv.org/abs/2302.04867) |
| [repaint](./repaint) | [**RePaint scheduler**](https://arxiv.org/abs/2201.09865) |
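Any scheduler in the table can also be loaded standalone from a model repository's `scheduler` subfolder. A minimal sketch (not part of this diff; the checkpoint id is illustrative):

```python
# Load a scheduler on its own from the `scheduler` subfolder of a
# diffusers model repo; the repo id here is only an example.
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
```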
## API
......
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# UniPC
## Overview
UniPC is a training-free framework designed for fast sampling of diffusion models. It consists of a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
For more details about the method, please refer to the [paper](https://arxiv.org/abs/2302.04867), *UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models*, and the [official code](https://github.com/wl-zhao/UniPC).
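The scheduler drops into an existing pipeline via `from_config`. A minimal usage sketch (not part of this commit; the checkpoint id and prompt are illustrative):

```python
import torch

from diffusers import DiffusionPipeline, UniPCMultistepScheduler

# Checkpoint id is only an example; any Stable Diffusion checkpoint works.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Reuse the pipeline's scheduler config so the beta schedule stays consistent.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# UniPC targets few-step sampling, so a small step count is typical.
image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
```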
## UniPCMultistepScheduler
[[autodoc]] UniPCMultistepScheduler
......@@ -84,6 +84,7 @@ else:
SchedulerMixin,
ScoreSdeVeScheduler,
UnCLIPScheduler,
UniPCMultistepScheduler,
VQDiffusionScheduler,
)
from .training_utils import EMAModel
......
......@@ -39,6 +39,7 @@ else:
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_unipc_multistep import UniPCMultistepScheduler
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
......
......@@ -159,7 +159,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["logrho"]:
if solver_type in ["midpoint", "heun"]:
if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
solver_type = "logrho"
else:
raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")
......
......@@ -169,7 +169,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        else:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
        if solver_type not in ["midpoint", "heun"]:
-            if solver_type == "logrho":
+            if solver_type in ["logrho", "bh1", "bh2"]:
                solver_type = "midpoint"
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
......
......@@ -168,7 +168,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
        else:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
        if solver_type not in ["midpoint", "heun"]:
-            if solver_type == "logrho":
+            if solver_type in ["logrho", "bh1", "bh2"]:
                solver_type = "midpoint"
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
......
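The three hunks above apply the same compatibility rule: a `solver_type` owned by a sibling scheduler (UniPC's `bh1`/`bh2`, DEIS's `logrho`, DPM-Solver's `midpoint`/`heun`) is silently mapped to the local default instead of raising, so one config can be shared across scheduler classes. A standalone sketch of the rule (illustrative; not the library code):

```python
# Illustrative fallback: map a sibling scheduler's solver type to the
# local default ("midpoint" here), mirroring the DPM-Solver hunks above.
LOCAL_TYPES = ["midpoint", "heun"]        # supported by this scheduler
SIBLING_TYPES = ["logrho", "bh1", "bh2"]  # DEIS / UniPC solver types

def resolve_solver_type(solver_type: str) -> str:
    if solver_type in LOCAL_TYPES:
        return solver_type
    if solver_type in SIBLING_TYPES:
        return "midpoint"  # degrade gracefully instead of raising
    raise NotImplementedError(f"{solver_type} is not implemented")
```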
This diff is collapsed (the new `src/diffusers/schedulers/scheduling_unipc_multistep.py` file, which implements the scheduler).
......@@ -42,6 +42,7 @@ class KarrasDiffusionSchedulers(Enum):
KDPM2DiscreteScheduler = 10
KDPM2AncestralDiscreteScheduler = 11
DEISMultistepScheduler = 12
UniPCMultistepScheduler = 13
@dataclass
......
......@@ -600,6 +600,21 @@ class UnCLIPScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class UniPCMultistepScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class VQDiffusionScheduler(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -40,6 +40,7 @@ from diffusers import (
PNDMScheduler,
ScoreSdeVeScheduler,
UnCLIPScheduler,
UniPCMultistepScheduler,
VQDiffusionScheduler,
logging,
)
......@@ -2636,6 +2637,200 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
assert sample.dtype == torch.float16
class UniPCMultistepSchedulerTest(SchedulerCommonTest):
scheduler_classes = (UniPCMultistepScheduler,)
forward_default_kwargs = (("num_inference_steps", 25),)
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"solver_order": 2,
}
config.update(**kwargs)
return config
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
output = scheduler.step(residual, t, output, **kwargs).prev_sample
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals (must be after setting timesteps)
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals (must be after setting timesteps)
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
return sample
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
time_step_0 = scheduler.timesteps[5]
time_step_1 = scheduler.timesteps[6]
output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_timesteps(self):
for timesteps in [25, 50, 100, 999, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_thresholding(self):
self.check_over_configs(thresholding=False)
for order in [1, 2, 3]:
for solver_type in ["bh1", "bh2"]:
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
solver_order=order,
solver_type=solver_type,
)
def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_solver_order_and_type(self):
for solver_type in ["bh1", "bh2"]:
for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
)
sample = self.full_loop(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
)
assert not torch.isnan(sample).any(), "Samples have nan numbers"
def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)
def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
def test_full_loop_no_noise(self):
sample = self.full_loop()
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2521) < 1e-3
def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.1096) < 1e-3
def test_fp16_support(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter.half()
scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
assert sample.dtype == torch.float16
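Outside the test harness, the loop that `full_loop` exercises reduces to the usual `set_timesteps`/`step` pattern. A hedged sketch using the same config as `get_scheduler_config` and a stand-in for the model:

```python
import torch

from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
    solver_order=2,
)
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 8, 8)  # stand-in latent
for t in scheduler.timesteps:
    residual = 0.1 * sample  # stand-in for a real model(sample, t) call
    sample = scheduler.step(residual, t, sample).prev_sample
```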
class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
scheduler_classes = (KDPM2AncestralDiscreteScheduler,)
num_inference_steps = 10
......