Unverified Commit a54cfe68 authored by Sid Sahai's avatar Sid Sahai Committed by GitHub
Browse files

Add LMSDiscreteSchedulerTest (#467)



* [WIP] add LMSDiscreteSchedulerTest

* fixes for comments

* add torch numpy test

* rebase

* Update tests/test_scheduler.py

* Update tests/test_scheduler.py

* style

* return residuals
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>
parent 88972172
...@@ -19,7 +19,7 @@ from typing import Dict, List, Tuple ...@@ -19,7 +19,7 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
...@@ -853,3 +853,83 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -853,3 +853,83 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
class LMSDiscreteSchedulerTest(SchedulerCommonTest):
scheduler_classes = (LMSDiscreteScheduler,)
num_inference_steps = 10
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1100,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"trained_betas": None,
"tensor_format": "pt",
}
config.update(**kwargs)
return config
def test_timesteps(self):
for timesteps in [10, 50, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "scaled_linear"]:
self.check_over_configs(beta_schedule=schedule)
def test_time_indices(self):
for t in [0, 500, 800]:
self.check_over_forward(time_step=t)
def test_pytorch_equal_numpy(self):
for scheduler_class in self.scheduler_classes:
sample_pt = self.dummy_sample
residual_pt = 0.1 * sample_pt
sample = sample_pt.numpy()
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler_config["tensor_format"] = "np"
scheduler = scheduler_class(**scheduler_config)
scheduler_config["tensor_format"] = "pt"
scheduler_pt = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
scheduler_pt.set_timesteps(self.num_inference_steps)
output = scheduler.step(residual, 1, sample).prev_sample
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.sigmas[0]
for i, t in enumerate(scheduler.timesteps):
sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
model_output = model(sample, t)
output = scheduler.step(model_output, i, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment