Unverified Commit 940f9410 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

Add `test_full_loop_with_noise` tests to all scheduler with `add_nosie` function (#5184)



* add fast tests for dpm-multi

* add more tests

* style

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent ad06e510
......@@ -115,6 +115,45 @@ class CMStochasticIterativeSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 347.6357) < 1e-2
assert abs(result_mean.item() - 0.4527) < 1e-3
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
t_start = 8
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
generator = torch.manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for t in timesteps:
# 1. scale model input
scaled_sample = scheduler.scale_model_input(sample, t)
# 2. predict noise residual
residual = model(scaled_sample, t)
# 3. predict previous sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
sample = pred_prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 763.9186) < 1e-2, f" expected result sum 763.9186, but get {result_sum}"
assert abs(result_mean.item() - 0.9947) < 1e-3, f" expected result mean 0.9947, but get {result_mean}"
def test_custom_timesteps_increasing_order(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
......
......@@ -146,3 +146,31 @@ class DDIMSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 149.0784) < 1e-2
assert abs(result_mean.item() - 0.1941) < 1e-3
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps, eta = 10, 0.0
t_start = 8
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for t in timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 354.5418) < 1e-2, f" expected result sum 218.4379, but get {result_sum}"
assert abs(result_mean.item() - 0.4616) < 1e-3, f" expected result mean 0.2844, but get {result_mean}"
......@@ -186,3 +186,31 @@ class DDIMParallelSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 149.0784) < 1e-2
assert abs(result_mean.item() - 0.1941) < 1e-3
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps, eta = 10, 0.0
t_start = 8
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for t in timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 354.5418) < 1e-2, f" expected result sum 354.5418, but get {result_sum}"
assert abs(result_mean.item() - 0.4616) < 1e-3, f" expected result mean 0.4616, but get {result_mean}"
......@@ -185,3 +185,34 @@ class DDPMSchedulerTest(SchedulerCommonTest):
msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}",
):
scheduler.set_timesteps(timesteps=timesteps)
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_trained_timesteps = len(scheduler)
t_start = num_trained_timesteps - 2
model = self.dummy_model()
sample = self.dummy_sample_deter
generator = torch.manual_seed(0)
# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for t in timesteps:
# 1. predict noise residual
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
sample = pred_prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 387.9466) < 1e-2, f" expected result sum 387.9466, but get {result_sum}"
assert abs(result_mean.item() - 0.5051) < 1e-3, f" expected result mean 0.5051, but get {result_mean}"
......@@ -214,3 +214,34 @@ class DDPMParallelSchedulerTest(SchedulerCommonTest):
msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}",
):
scheduler.set_timesteps(timesteps=timesteps)
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_trained_timesteps = len(scheduler)
t_start = num_trained_timesteps - 2
model = self.dummy_model()
sample = self.dummy_sample_deter
generator = torch.manual_seed(0)
# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for t in timesteps:
# 1. predict noise residual
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
sample = pred_prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 387.9466) < 1e-2, f" expected result sum 387.9466, but get {result_sum}"
assert abs(result_mean.item() - 0.5051) < 1e-3, f" expected result mean 0.5051, but get {result_mean}"
......@@ -236,3 +236,30 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
sample = scheduler.step(residual, t, sample).prev_sample
assert sample.dtype == torch.float16
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
t_start = 8
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 315.3016) < 1e-2, f" expected result sum 315.3016, but get {result_sum}"
assert abs(result_mean.item() - 0.41054) < 1e-3, f" expected result mean 0.41054, but get {result_mean}"
......@@ -213,6 +213,33 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_mean.item() - 0.3301) < 1e-3
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
t_start = 5
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 318.4111) < 1e-2, f" expected result sum 318.4111, but get {result_sum}"
assert abs(result_mean.item() - 0.4146) < 1e-3, f" expected result mean 0.4146, but get {result_mean}"
def test_full_loop_no_noise_thres(self):
sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
result_mean = torch.mean(torch.abs(sample))
......
......@@ -279,3 +279,30 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
t_start = 5
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 269.2187) < 1e-2, f" expected result sum 269.2187, but get {result_sum}"
assert abs(result_mean.item() - 0.3505) < 1e-3, f" expected result mean 0.3505, but get {result_mean}"
......@@ -144,3 +144,36 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 124.52299499511719) < 1e-2
assert abs(result_mean.item() - 0.16213932633399963) < 1e-3
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
generator = torch.manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
# add noise
t_start = self.num_inference_steps - 2
noise = self.dummy_noise_deter
noise = noise.to(sample.device)
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 57062.9297) < 1e-2, f" expected result sum 57062.9297, but get {result_sum}"
assert abs(result_mean.item() - 74.3007) < 1e-3, f" expected result mean 74.3007, but get {result_mean}"
......@@ -116,3 +116,37 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
t_start = self.num_inference_steps - 2
scheduler.set_timesteps(self.num_inference_steps)
generator = torch.manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
# add noise
noise = self.dummy_noise_deter
noise = noise.to(sample.device)
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 56163.0508) < 1e-2, f" expected result sum 56163.0508, but get {result_sum}"
assert abs(result_mean.item() - 73.1290) < 1e-3, f" expected result mean 73.1290, but get {result_mean}"
......@@ -158,3 +158,34 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 0.00015) < 1e-2
assert abs(result_mean.item() - 1.9869554535034695e-07) < 1e-2
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
t_start = self.num_inference_steps - 2
noise = self.dummy_noise_deter
noise = noise.to(torch_device)
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 75074.8906) < 1e-2, f" expected result sum 75074.8906, but get {result_sum}"
assert abs(result_mean.item() - 97.7538) < 1e-3, f" expected result mean 97.7538, but get {result_mean}"
......@@ -121,3 +121,38 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 13849.3818) < 1e-1
assert abs(result_mean.item() - 18.0331) < 1e-3
def test_full_loop_with_noise(self):
if torch_device == "mps":
return
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
generator = torch.manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
# add noise
t_start = self.num_inference_steps - 2
noise = self.dummy_noise_deter
noise = noise.to(sample.device)
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample, generator=generator)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 93087.0312) < 1e-2, f" expected result sum 93087.0312, but get {result_sum}"
assert abs(result_mean.item() - 121.2071) < 5e-3, f" expected result mean 121.2071, but get {result_mean}"
......@@ -130,3 +130,37 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
# CUDA
assert abs(result_sum.item() - 20.4125) < 1e-2
assert abs(result_mean.item() - 0.0266) < 1e-3
def test_full_loop_with_noise(self):
if torch_device == "mps":
return
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
# add noise
t_start = self.num_inference_steps - 2
noise = self.dummy_noise_deter
noise = noise.to(sample.device)
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 70408.4062) < 1e-2, f" expected result sum 70408.4062, but get {result_sum}"
assert abs(result_mean.item() - 91.6776) < 1e-3, f" expected result mean 91.6776, but get {result_mean}"
......@@ -138,3 +138,33 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 3812.9927) < 2e-2
assert abs(result_mean.item() - 4.9648) < 1e-3
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
# add noise
t_start = self.num_inference_steps - 2
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 27663.6895) < 1e-2
assert abs(result_mean.item() - 36.0204) < 1e-3
......@@ -242,3 +242,30 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
sample = scheduler.step(residual, t, sample).prev_sample
assert sample.dtype == torch.float16
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
t_start = 8
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}"
assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}"
......@@ -264,6 +264,21 @@ class SchedulerCommonTest(unittest.TestCase):
return sample
@property
def dummy_noise_deter(self):
batch_size = 4
num_channels = 3
height = 8
width = 8
num_elems = batch_size * num_channels * height * width
sample = torch.arange(num_elems).flip(-1)
sample = sample.reshape(num_channels, height, width, batch_size)
sample = sample / num_elems
sample = sample.permute(3, 0, 1, 2)
return sample
@property
def dummy_sample_deter(self):
batch_size = 4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment