Commit 687bc277, authored by Michael and committed by GitHub

add TCD Scheduler (#7174)



* add: support TCD scheduler


---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 6246c70d
@@ -418,6 +418,8 @@
       title: ScoreSdeVeScheduler
     - local: api/schedulers/score_sde_vp
       title: ScoreSdeVpScheduler
+    - local: api/schedulers/tcd
+      title: TCDScheduler
     - local: api/schedulers/unipc
       title: UniPCMultistepScheduler
     - local: api/schedulers/vq_diffusion
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# TCDScheduler
[Trajectory Consistency Distillation](https://huggingface.co/papers/2402.19159) by Jianbin Zheng, Minghui Hu, Zhongyi Fan, Chaoyue Wang, Changxing Ding, Dacheng Tao and Tat-Jen Cham introduced Strategic Stochastic Sampling (Algorithm 4), which can generate good samples in a small number of steps. An advanced iteration of the multistep scheduler (Algorithm 1) from [Consistency Models](https://huggingface.co/papers/2303.01469), Strategic Stochastic Sampling is specifically tailored to the trajectory consistency function.
The abstract from the paper is:
*Latent Consistency Model (LCM) extends the Consistency Model to the latent space and leverages the guided consistency distillation technique to achieve impressive performance in accelerating text-to-image synthesis. However, we observed that LCM struggles to generate images with both clarity and detailed intricacy. To address this limitation, we initially delve into and elucidate the underlying causes. Our investigation identifies that the primary issue stems from errors in three distinct areas. Consequently, we introduce Trajectory Consistency Distillation (TCD), which encompasses trajectory consistency function and strategic stochastic sampling. The trajectory consistency function diminishes the distillation errors by broadening the scope of the self-consistency boundary condition and endowing the TCD with the ability to accurately trace the entire trajectory of the Probability Flow ODE. Additionally, strategic stochastic sampling is specifically designed to circumvent the accumulated errors inherent in multi-step consistency sampling, which is meticulously tailored to complement the TCD model. Experiments demonstrate that TCD not only significantly enhances image quality at low NFEs but also yields more detailed results compared to the teacher model at high NFEs.*
The original codebase can be found at [jabir-zheng/TCD](https://github.com/jabir-zheng/TCD).
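As a rough illustration of how the stochasticity parameter (`gamma` in the paper, passed as `eta` in the tests in this commit) might shape a sampling step, the sketch below computes an intermediate timestep from the next timestep in the schedule. It assumes the intermediate timestep is chosen as `floor((1 - gamma) * prev_timestep)`, which is our reading of Algorithm 4 and not something this page confirms; the function name `intermediate_timestep` is hypothetical and is not part of the `TCDScheduler` API.

```python
import math


def intermediate_timestep(prev_timestep: int, gamma: float) -> int:
    """Hypothetical helper: pick the timestep the model jumps to before re-noising.

    Assumption (not taken from the scheduler source): s = floor((1 - gamma) * prev_timestep).
    """
    if not 0.0 <= gamma <= 1.0:
        raise ValueError("gamma is assumed to lie in [0, 1]")
    return math.floor((1.0 - gamma) * prev_timestep)


# gamma = 0.0 lands exactly on the next scheduled timestep (no re-noising needed),
# while gamma = 1.0 jumps all the way to timestep 0 before noise is added back.
for gamma in (0.0, 0.5, 1.0):
    print(gamma, intermediate_timestep(500, gamma))
```

Under this assumption, `gamma = 0.0` corresponds to the deterministic setting exercised by the full-loop tests, and larger values trade determinism for more injected noise per step.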
## TCDScheduler
[[autodoc]] TCDScheduler
## TCDSchedulerOutput
[[autodoc]] schedulers.scheduling_tcd.TCDSchedulerOutput
@@ -160,6 +160,7 @@ else:
             "SASolverScheduler",
             "SchedulerMixin",
             "ScoreSdeVeScheduler",
+            "TCDScheduler",
             "UnCLIPScheduler",
             "UniPCMultistepScheduler",
             "VQDiffusionScheduler",
@@ -545,6 +546,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             SASolverScheduler,
             SchedulerMixin,
             ScoreSdeVeScheduler,
+            TCDScheduler,
             UnCLIPScheduler,
             UniPCMultistepScheduler,
             VQDiffusionScheduler,
......
@@ -65,6 +65,7 @@ else:
     _import_structure["scheduling_repaint"] = ["RePaintScheduler"]
     _import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
     _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
+    _import_structure["scheduling_tcd"] = ["TCDScheduler"]
     _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
     _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
     _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
@@ -159,6 +160,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from .scheduling_repaint import RePaintScheduler
         from .scheduling_sasolver import SASolverScheduler
         from .scheduling_sde_ve import ScoreSdeVeScheduler
+        from .scheduling_tcd import TCDScheduler
         from .scheduling_unclip import UnCLIPScheduler
         from .scheduling_unipc_multistep import UniPCMultistepScheduler
         from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......
This diff is collapsed.
@@ -1095,6 +1095,21 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class TCDScheduler(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class UnCLIPScheduler(metaclass=DummyObject):
     _backends = ["torch"]
......
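The placeholder class in the hunk above exists so that importing the package does not fail when an optional backend is missing, while any attempt to use the class does. The following is a minimal sketch of that general pattern only. `require_backends_sketch` and `PlaceholderScheduler` are hypothetical names, and the bodies are simplified stand-ins, not the `DummyObject` metaclass or `requires_backends` helper that the diff refers to.

```python
import importlib.util


def require_backends_sketch(obj, backends):
    """Hypothetical stand-in: raise ImportError if any named backend is not importable."""
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the following backends: {', '.join(missing)}")


class PlaceholderScheduler:
    # A deliberately fake backend name so the sketch always takes the failure path.
    _backends = ["a_backend_that_is_not_installed"]

    def __init__(self, *args, **kwargs):
        require_backends_sketch(self, self._backends)

    @classmethod
    def from_config(cls, *args, **kwargs):
        require_backends_sketch(cls, cls._backends)


try:
    PlaceholderScheduler()
except ImportError as err:
    print(err)
```

Defining the class never touches the backend, so the import succeeds. The error surfaces only on construction or on one of the class-method entry points.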
import torch

from diffusers import TCDScheduler

from .test_schedulers import SchedulerCommonTest


class TCDSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (TCDScheduler,)
    forward_default_kwargs = (("num_inference_steps", 10),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.00085,
            "beta_end": 0.0120,
            "beta_schedule": "scaled_linear",
            "prediction_type": "epsilon",
        }

        config.update(**kwargs)
        return config

    @property
    def default_num_inference_steps(self):
        return 10

    @property
    def default_valid_timestep(self):
        # Use the last timestep of the default schedule as a timestep known to be valid.
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        scheduler_config = self.get_scheduler_config()
        scheduler = self.scheduler_classes[0](**scheduler_config)

        scheduler.set_timesteps(num_inference_steps)
        timestep = scheduler.timesteps[-1]
        return timestep

    def test_timesteps(self):
        for timesteps in [100, 500, 1000]:
            # 0 is not guaranteed to be in the timestep schedule, but timesteps - 1 is
            self.check_over_configs(time_step=timesteps - 1, num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(time_step=self.default_valid_timestep, beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "scaled_linear", "squaredcos_cap_v2"]:
            self.check_over_configs(time_step=self.default_valid_timestep, beta_schedule=schedule)

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(time_step=self.default_valid_timestep, prediction_type=prediction_type)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(time_step=self.default_valid_timestep, clip_sample=clip_sample)

    def test_thresholding(self):
        self.check_over_configs(time_step=self.default_valid_timestep, thresholding=False)
        for threshold in [0.5, 1.0, 2.0]:
            for prediction_type in ["epsilon", "v_prediction"]:
                self.check_over_configs(
                    time_step=self.default_valid_timestep,
                    thresholding=True,
                    prediction_type=prediction_type,
                    sample_max_value=threshold,
                )

    def test_time_indices(self):
        # Get default timestep schedule.
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        scheduler_config = self.get_scheduler_config()
        scheduler = self.scheduler_classes[0](**scheduler_config)

        scheduler.set_timesteps(num_inference_steps)
        timesteps = scheduler.timesteps
        for t in timesteps:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        # Hardcoded for now
        for t, num_inference_steps in zip([99, 39, 39, 19], [10, 25, 26, 50]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def full_loop(self, num_inference_steps=10, seed=0, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        eta = 0.0  # corresponds to gamma in the paper
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        generator = torch.manual_seed(seed)
        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample, eta, generator).prev_sample

        return sample

    def test_full_loop_onestep_deter(self):
        sample = self.full_loop(num_inference_steps=1)

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 29.8715) < 1e-3  # 0.0778918
        assert abs(result_mean.item() - 0.0389) < 1e-3

    def test_full_loop_multistep_deter(self):
        sample = self.full_loop(num_inference_steps=10)

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 181.2040) < 1e-3
        assert abs(result_mean.item() - 0.2359) < 1e-3

    def test_custom_timesteps(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = [100, 87, 50, 1, 0]

        scheduler.set_timesteps(timesteps=timesteps)

        scheduler_timesteps = scheduler.timesteps

        for i, timestep in enumerate(scheduler_timesteps):
            # The last timestep has no successor, so -1 is expected there.
            if i == len(timesteps) - 1:
                expected_prev_t = -1
            else:
                expected_prev_t = timesteps[i + 1]

            prev_t = scheduler.previous_timestep(timestep)
            prev_t = prev_t.item()

            self.assertEqual(prev_t, expected_prev_t)

    def test_custom_timesteps_increasing_order(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = [100, 87, 50, 51, 0]

        with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
            scheduler.set_timesteps(timesteps=timesteps)

    def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = [100, 87, 50, 1, 0]
        num_inference_steps = len(timesteps)

        with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
            scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)

    def test_custom_timesteps_too_large(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = [scheduler.config.num_train_timesteps]

        with self.assertRaises(
            ValueError,
            msg=f"`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}",
        ):
            scheduler.set_timesteps(timesteps=timesteps)
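
The custom-timestep tests above pin down three rules: `num_inference_steps` and `timesteps` are mutually exclusive, custom timesteps must be strictly descending, and the first timestep must be smaller than `num_train_timesteps`. They also expect `-1` as the predecessor of the final timestep. The sketch below restates those rules in plain Python, as inferred from the tests alone. The scheduler's own implementation is collapsed on this page, so `check_custom_timesteps` and `previous_timestep_sketch` are hypothetical helpers, not the library's code.

```python
from typing import List, Optional


def check_custom_timesteps(
    num_train_timesteps: int,
    num_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
) -> List[int]:
    """Hypothetical validator mirroring the rules the tests exercise."""
    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
    if not timesteps:
        raise ValueError("This sketch only handles a non-empty custom schedule.")
    # Every timestep must be strictly smaller than the one before it.
    for earlier, later in zip(timesteps, timesteps[1:]):
        if later >= earlier:
            raise ValueError("`timesteps` must be in descending order.")
    if timesteps[0] >= num_train_timesteps:
        raise ValueError(f"`timesteps` must start before {num_train_timesteps}.")
    return list(timesteps)


def previous_timestep_sketch(timesteps: List[int], timestep: int) -> int:
    """Return the timestep that follows `timestep` in the schedule, or -1 at the end."""
    index = timesteps.index(timestep)
    return timesteps[index + 1] if index + 1 < len(timesteps) else -1


schedule = check_custom_timesteps(1000, timesteps=[100, 87, 50, 1, 0])
print([previous_timestep_sketch(schedule, t) for t in schedule])
```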