Commit 12b10cbe authored by Patrick von Platen's avatar Patrick von Platen
Browse files

finish refactor

parent 2d97544d
...@@ -270,7 +270,7 @@ def reset_format() -> None: ...@@ -270,7 +270,7 @@ def reset_format() -> None:
def warning_advice(self, *args, **kwargs): def warning_advice(self, *args, **kwargs):
""" """
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed warning will not be printed
""" """
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False) no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
......
...@@ -19,11 +19,10 @@ import unittest ...@@ -19,11 +19,10 @@ import unittest
import torch import torch
from diffusers import GaussianDDPMScheduler, UNetModel, DDIMScheduler from diffusers import DDIM, DDPM, DDIMScheduler, GaussianDDPMScheduler, LatentDiffusion, UNetModel
from diffusers import DDIM, DDPM, LatentDiffusion
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, torch_device, slow from diffusers.testing_utils import floats_tensor, slow, torch_device
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
...@@ -149,6 +148,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -149,6 +148,7 @@ class PipelineTesterMixin(unittest.TestCase):
unet = UNetModel.from_pretrained(model_id) unet = UNetModel.from_pretrained(model_id)
noise_scheduler = GaussianDDPMScheduler.from_config(model_id) noise_scheduler = GaussianDDPMScheduler.from_config(model_id)
noise_scheduler = noise_scheduler.set_format("pt")
ddpm = DDPM(unet=unet, noise_scheduler=noise_scheduler) ddpm = DDPM(unet=unet, noise_scheduler=noise_scheduler)
image = ddpm(generator=generator) image = ddpm(generator=generator)
...@@ -165,7 +165,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -165,7 +165,7 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-cifar10" model_id = "fusing/ddpm-cifar10"
unet = UNetModel.from_pretrained(model_id) unet = UNetModel.from_pretrained(model_id)
noise_scheduler = DDIMScheduler() noise_scheduler = DDIMScheduler(tensor_format="pt")
ddim = DDIM(unet=unet, noise_scheduler=noise_scheduler) ddim = DDIM(unet=unet, noise_scheduler=noise_scheduler)
image = ddim(generator=generator, eta=0.0) image = ddim(generator=generator, eta=0.0)
......
...@@ -14,12 +14,13 @@ ...@@ -14,12 +14,13 @@
# limitations under the License. # limitations under the License.
import torch
import numpy as np
import unittest
import tempfile import tempfile
import unittest
from diffusers import GaussianDDPMScheduler, DDIMScheduler import numpy as np
import torch
from diffusers import DDIMScheduler, GaussianDDPMScheduler
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
...@@ -38,7 +39,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -38,7 +39,7 @@ class SchedulerCommonTest(unittest.TestCase):
image = np.random.rand(batch_size, num_channels, height, width) image = np.random.rand(batch_size, num_channels, height, width)
return torch.tensor(image) return image
@property @property
def dummy_image_deter(self): def dummy_image_deter(self):
...@@ -53,7 +54,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -53,7 +54,7 @@ class SchedulerCommonTest(unittest.TestCase):
image = image / num_elems image = image / num_elems
image = image.transpose(3, 0, 1, 2) image = image.transpose(3, 0, 1, 2)
return torch.tensor(image) return image
def get_scheduler_config(self): def get_scheduler_config(self):
raise NotImplementedError raise NotImplementedError
...@@ -82,7 +83,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -82,7 +83,7 @@ class SchedulerCommonTest(unittest.TestCase):
output = scheduler.step(residual, image, time_step, **kwargs) output = scheduler.step(residual, image, time_step, **kwargs)
new_output = new_scheduler.step(residual, image, time_step, **kwargs) new_output = new_scheduler.step(residual, image, time_step, **kwargs)
assert (output - new_output).abs().sum() < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward(self, time_step=0, **forward_kwargs): def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
...@@ -103,7 +104,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -103,7 +104,7 @@ class SchedulerCommonTest(unittest.TestCase):
output = scheduler.step(residual, image, time_step, **kwargs) output = scheduler.step(residual, image, time_step, **kwargs)
new_output = new_scheduler.step(residual, image, time_step, **kwargs) new_output = new_scheduler.step(residual, image, time_step, **kwargs)
assert (output - new_output).abs().sum() < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_from_pretrained_save_pretrained(self): def test_from_pretrained_save_pretrained(self):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
...@@ -122,7 +123,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -122,7 +123,7 @@ class SchedulerCommonTest(unittest.TestCase):
output = scheduler.step(residual, image, 1, **kwargs) output = scheduler.step(residual, image, 1, **kwargs)
new_output = new_scheduler.step(residual, image, 1, **kwargs) new_output = new_scheduler.step(residual, image, 1, **kwargs)
assert (output - new_output).abs().sum() < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_step_shape(self): def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
...@@ -140,6 +141,26 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -140,6 +141,26 @@ class SchedulerCommonTest(unittest.TestCase):
self.assertEqual(output_0.shape, image.shape) self.assertEqual(output_0.shape, image.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs)
for scheduler_class in self.scheduler_classes:
image = self.dummy_image
residual = 0.1 * image
image_pt = torch.tensor(image)
residual_pt = 0.1 * image_pt
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
output = scheduler.step(residual, image, 1, **kwargs)
output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs)
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-5, "Scheduler outputs are not identical"
class DDPMSchedulerTest(SchedulerCommonTest): class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (GaussianDDPMScheduler,) scheduler_classes = (GaussianDDPMScheduler,)
...@@ -151,7 +172,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -151,7 +172,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
"beta_end": 0.02, "beta_end": 0.02,
"beta_schedule": "linear", "beta_schedule": "linear",
"variance_type": "fixed_small", "variance_type": "fixed_small",
"clip_predicted_image": True "clip_predicted_image": True,
} }
config.update(**kwargs) config.update(**kwargs)
...@@ -186,9 +207,9 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -186,9 +207,9 @@ class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
assert (scheduler.get_variance(0) - 0.0).abs().sum() < 1e-5 assert np.sum(np.abs(scheduler.get_variance(0) - 0.0)) < 1e-5
assert (scheduler.get_variance(487) - 0.00979).abs().sum() < 1e-5 assert np.sum(np.abs(scheduler.get_variance(487) - 0.00979)) < 1e-5
assert (scheduler.get_variance(999) - 0.02).abs().sum() < 1e-5 assert np.sum(np.abs(scheduler.get_variance(999) - 0.02)) < 1e-5
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
...@@ -209,12 +230,12 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -209,12 +230,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
if t > 0: if t > 0:
noise = self.dummy_image_deter noise = self.dummy_image_deter
variance = scheduler.get_variance(t).sqrt() * noise variance = scheduler.get_variance(t) ** (0.5) * noise
image = pred_prev_image + variance image = pred_prev_image + variance
result_sum = image.abs().sum() result_sum = np.sum(np.abs(image))
result_mean = image.abs().mean() result_mean = np.mean(np.abs(image))
assert result_sum.item() - 732.9947 < 1e-3 assert result_sum.item() - 732.9947 < 1e-3
assert result_mean.item() - 0.9544 < 1e-3 assert result_mean.item() - 0.9544 < 1e-3
...@@ -230,7 +251,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -230,7 +251,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
"beta_start": 0.0001, "beta_start": 0.0001,
"beta_end": 0.02, "beta_end": 0.02,
"beta_schedule": "linear", "beta_schedule": "linear",
"clip_predicted_image": True "clip_predicted_image": True,
} }
config.update(**kwargs) config.update(**kwargs)
...@@ -269,12 +290,12 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -269,12 +290,12 @@ class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
assert (scheduler.get_variance(0, 50) - 0.0).abs().sum() < 1e-5 assert np.sum(np.abs(scheduler.get_variance(0, 50) - 0.0)) < 1e-5
assert (scheduler.get_variance(21, 50) - 0.14771).abs().sum() < 1e-5 assert np.sum(np.abs(scheduler.get_variance(21, 50) - 0.14771)) < 1e-5
assert (scheduler.get_variance(49, 50) - 0.32460).abs().sum() < 1e-5 assert np.sum(np.abs(scheduler.get_variance(49, 50) - 0.32460)) < 1e-5
assert (scheduler.get_variance(0, 1000) - 0.0).abs().sum() < 1e-5 assert np.sum(np.abs(scheduler.get_variance(0, 1000) - 0.0)) < 1e-5
assert (scheduler.get_variance(487, 1000) - 0.00979).abs().sum() < 1e-5 assert np.sum(np.abs(scheduler.get_variance(487, 1000) - 0.00979)) < 1e-5
assert (scheduler.get_variance(999, 1000) - 0.02).abs().sum() < 1e-5 assert np.sum(np.abs(scheduler.get_variance(999, 1000) - 0.02)) < 1e-5
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
...@@ -297,12 +318,12 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -297,12 +318,12 @@ class DDIMSchedulerTest(SchedulerCommonTest):
variance = 0 variance = 0
if eta > 0: if eta > 0:
noise = self.dummy_image_deter noise = self.dummy_image_deter
variance = scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise
image = pred_prev_image + variance image = pred_prev_image + variance
result_sum = image.abs().sum() result_sum = np.sum(np.abs(image))
result_mean = image.abs().mean() result_mean = np.mean(np.abs(image))
assert result_sum.item() - 270.6214 < 1e-3 assert result_sum.item() - 270.6214 < 1e-3
assert result_mean.item() - 0.3524 < 1e-3 assert result_mean.item() - 0.3524 < 1e-3
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment