Unverified commit cedafb86, authored by Dudu Moshe, committed by GitHub

[Bug]: fix DDPM scheduler arbitrary infer steps count. (#2076)



scheduling_ddpm: fix evaluation with a lower timestep count than was used for training.
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 69caa964
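
In essence, the DDPM scheduler always stepped from t to t - 1, so when num_inference_steps was smaller than num_train_timesteps it looked up the wrong alphas_cumprod entry and used the wrong per-step beta. The sketch below only illustrates the fixed mapping under an assumed default linear beta schedule (names and values are placeholders, not the library code): the previous timestep becomes t minus the stride num_train_timesteps // num_inference_steps, and the effective beta is derived from the ratio of cumulative alphas across that stride.

    # Minimal sketch of the fixed previous-timestep mapping, assuming a
    # linear beta schedule; values and names are illustrative only.
    import torch

    num_train_timesteps = 1000
    num_inference_steps = 50
    betas = torch.linspace(1e-4, 2e-2, num_train_timesteps)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    t = 980  # a timestep actually visited when running 50 inference steps
    stride = num_train_timesteps // num_inference_steps  # 20, not 1

    prev_t = t - stride  # was t - 1 before the fix
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0)

    # Effective one-step beta over the 20-step jump; with 1000 inference
    # steps the stride is 1 and this reduces to betas[t], the old behaviour.
    current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
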
@@ -189,13 +189,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
+        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+        prev_t = t - self.config.num_train_timesteps // num_inference_steps
         alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
+        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
 
         # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous sample
         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
-        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
 
         if variance_type is None:
             variance_type = self.config.variance_type
@@ -208,10 +211,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             variance = torch.log(torch.clamp(variance, min=1e-20))
             variance = torch.exp(0.5 * variance)
         elif variance_type == "fixed_large":
-            variance = self.betas[t]
+            variance = current_beta_t
         elif variance_type == "fixed_large_log":
             # Glide max_log
-            variance = torch.log(self.betas[t])
+            variance = torch.log(current_beta_t)
         elif variance_type == "learned":
             return predicted_variance
         elif variance_type == "learned_range":
@@ -249,6 +252,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         """
         t = timestep
+        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+        prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
@@ -257,9 +262,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         # 1. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
+        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
+        current_beta_t = 1 - current_alpha_t
 
         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
@@ -281,8 +288,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
-        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
+        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
 
         # 5. Compute predicted previous sample µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
@@ -118,7 +118,7 @@ class PipelineFastTests(unittest.TestCase):
         assert image.height == self.dummy_unet.sample_size[0] and image.width == self.dummy_unet.sample_size[1]
 
         image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
         image_from_tuple_slice = np.frombuffer(image_from_tuple.tobytes(), dtype="uint8")[:10]
-        expected_slice = np.array([255, 255, 255, 0, 181, 0, 124, 0, 15, 255])
+        expected_slice = np.array([69, 255, 255, 255, 0, 0, 77, 181, 12, 127])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() == 0
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() == 0
@@ -40,7 +40,7 @@ class DDPMPipelineFastTests(unittest.TestCase):
         )
         return model
 
-    def test_inference(self):
+    def test_fast_inference(self):
         device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()
@@ -60,7 +60,33 @@ class DDPMPipelineFastTests(unittest.TestCase):
 
         assert image.shape == (1, 32, 32, 3)
         expected_slice = np.array(
-            [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
+            [9.956e-01, 5.785e-01, 4.675e-01, 9.930e-01, 0.0, 1.000, 1.199e-03, 2.648e-04, 5.101e-04]
         )
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_full_inference(self):
+        device = "cpu"
+        unet = self.dummy_uncond_unet
+        scheduler = DDPMScheduler()
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image = ddpm(generator=generator, output_type="numpy").images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = ddpm(generator=generator, output_type="numpy", return_dict=False)[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array(
+            [1.0, 3.495e-02, 2.939e-01, 9.821e-01, 9.448e-01, 6.261e-03, 7.998e-01, 8.9e-01, 1.122e-02]
+        )
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -686,8 +686,8 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
 
-        assert abs(result_sum.item() - 258.9070) < 1e-2
-        assert abs(result_mean.item() - 0.3374) < 1e-3
+        assert abs(result_sum.item() - 258.9606) < 1e-2
+        assert abs(result_mean.item() - 0.3372) < 1e-3
 
     def test_full_loop_with_v_prediction(self):
         scheduler_class = self.scheduler_classes[0]
@@ -717,8 +717,8 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
 
-        assert abs(result_sum.item() - 201.9864) < 1e-2
-        assert abs(result_mean.item() - 0.2630) < 1e-3
+        assert abs(result_sum.item() - 202.0296) < 1e-2
+        assert abs(result_mean.item() - 0.2631) < 1e-3
 
 
 class DDIMSchedulerTest(SchedulerCommonTest):