Unverified Commit c8b0f0eb authored by Leng Yue's avatar Leng Yue Committed by GitHub
Browse files

Update UniPC to support 1D diffusion. (#5199)



* Update Unipc einsum to support 1D and 3D diffusion.

* Add unittest

* Update unittest & edge case

* Fix unittest

* Fix testing_utils.py

* Fix unittest file

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7a4324cc
...@@ -282,13 +282,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -282,13 +282,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
https://arxiv.org/abs/2205.11487 https://arxiv.org/abs/2205.11487
""" """
dtype = sample.dtype dtype = sample.dtype
batch_size, channels, height, width = sample.shape batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64): if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image # Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width) sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value" abs_sample = sample.abs() # "a certain percentile absolute pixel value"
...@@ -300,7 +300,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -300,7 +300,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width) sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype) sample = sample.to(dtype)
return sample return sample
...@@ -534,14 +534,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -534,14 +534,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.predict_x0: if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None: if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else: else:
pred_res = 0 pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res x_t = x_t_ - alpha_t * B_h * pred_res
else: else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None: if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else: else:
pred_res = 0 pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res x_t = x_t_ - sigma_t * B_h * pred_res
...@@ -670,7 +670,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -670,7 +670,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.predict_x0: if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None: if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else: else:
corr_res = 0 corr_res = 0
D1_t = model_t - m0 D1_t = model_t - m0
...@@ -678,7 +678,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -678,7 +678,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
else: else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None: if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else: else:
corr_res = 0 corr_res = 0
D1_t = model_t - m0 D1_t = model_t - m0
......
...@@ -269,3 +269,113 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): ...@@ -269,3 +269,113 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}"
assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}"
class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):
    """Re-runs the UniPC multistep scheduler test suite on 1D samples.

    The parent suite exercises image-shaped (batch, channels, height, width)
    latents; this subclass swaps the dummy fixtures for (batch, channels,
    width) tensors and pins the expected numerics for the 1D case.
    """

    @property
    def dummy_sample(self):
        # Random 1D latent: (batch=4, channels=3, width=8).
        return torch.rand((4, 3, 8))

    @property
    def dummy_noise_deter(self):
        # Deterministic "noise": a descending ramp normalized into [0, 1),
        # laid out (channels, width, batch) then permuted so batch is first.
        total = 4 * 3 * 8
        ramp = torch.arange(total).flip(-1).reshape(3, 8, 4)
        return (ramp / total).permute(2, 0, 1)

    @property
    def dummy_sample_deter(self):
        # Deterministic sample: an ascending ramp normalized into [0, 1),
        # same layout trick as dummy_noise_deter.
        total = 4 * 3 * 8
        ramp = torch.arange(total).reshape(3, 8, 4)
        return (ramp / total).permute(2, 0, 1)

    def test_switch(self):
        # Round-tripping a config through the sibling schedulers and back to
        # UniPC must reproduce the defaults' result.
        scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
        mean_abs = torch.mean(torch.abs(self.full_loop(scheduler=scheduler)))
        assert abs(mean_abs.item() - 0.2441) < 1e-3

        for scheduler_cls in (
            DPMSolverSinglestepScheduler,
            DEISMultistepScheduler,
            DPMSolverMultistepScheduler,
            UniPCMultistepScheduler,
        ):
            scheduler = scheduler_cls.from_config(scheduler.config)

        mean_abs = torch.mean(torch.abs(self.full_loop(scheduler=scheduler)))
        assert abs(mean_abs.item() - 0.2441) < 1e-3

    def test_full_loop_no_noise(self):
        mean_abs = torch.mean(torch.abs(self.full_loop()))
        assert abs(mean_abs.item() - 0.2441) < 1e-3

    def test_full_loop_with_karras(self):
        mean_abs = torch.mean(torch.abs(self.full_loop(use_karras_sigmas=True)))
        assert abs(mean_abs.item() - 0.2898) < 1e-3

    def test_full_loop_with_v_prediction(self):
        mean_abs = torch.mean(torch.abs(self.full_loop(prediction_type="v_prediction")))
        assert abs(mean_abs.item() - 0.1014) < 1e-3

    def test_full_loop_with_karras_and_v_prediction(self):
        mean_abs = torch.mean(
            torch.abs(self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True))
        )
        assert abs(mean_abs.item() - 0.1944) < 1e-3

    def test_full_loop_with_noise(self):
        # Resume denoising mid-trajectory: noise the deterministic sample at
        # the timestep where we restart, then run the remaining steps.
        scheduler_class = self.scheduler_classes[0]
        scheduler = scheduler_class(**self.get_scheduler_config())

        scheduler.set_timesteps(10)
        model = self.dummy_model()

        t_start = 8
        remaining_timesteps = scheduler.timesteps[t_start * scheduler.order :]
        sample = scheduler.add_noise(
            self.dummy_sample_deter, self.dummy_noise_deter, remaining_timesteps[:1]
        )

        for t in remaining_timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        abs_sample = torch.abs(sample)
        result_sum = torch.sum(abs_sample)
        result_mean = torch.mean(abs_sample)

        assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}"
        assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment