Commit b76eea04 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

check with other device

parent 5da71f8f
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import random import random
import tempfile import tempfile
import unittest import unittest
...@@ -22,7 +21,6 @@ import os ...@@ -22,7 +21,6 @@ import os
from distutils.util import strtobool from distutils.util import strtobool
import torch import torch
import numpy as np
from diffusers import GaussianDDPMScheduler, UNetModel from diffusers import GaussianDDPMScheduler, UNetModel
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
...@@ -31,22 +29,7 @@ from models.vision.ddpm.modeling_ddpm import DDPM ...@@ -31,22 +29,7 @@ from models.vision.ddpm.modeling_ddpm import DDPM
global_rng = random.Random() global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = False
def get_random_generator(seed):
seed = 1234
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
generator = torch.Generator()
return generator
def parse_flag_from_env(key, default=False): def parse_flag_from_env(key, default=False):
...@@ -132,7 +115,7 @@ class SamplerTesterMixin(unittest.TestCase): ...@@ -132,7 +115,7 @@ class SamplerTesterMixin(unittest.TestCase):
@slow @slow
def test_sample(self): def test_sample(self):
generator = get_random_generator(0) generator = torch.manual_seed(0)
# 1. Load models # 1. Load models
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church") scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
...@@ -182,13 +165,12 @@ class SamplerTesterMixin(unittest.TestCase): ...@@ -182,13 +165,12 @@ class SamplerTesterMixin(unittest.TestCase):
def test_sample_fast(self): def test_sample_fast(self):
# 1. Load models # 1. Load models
generator = get_random_generator(0) generator = torch.manual_seed(0)
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10) scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise # 2. Sample gaussian noise
torch.manual_seed(0)
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
# 3. Denoise # 3. Denoise
...@@ -218,8 +200,8 @@ class SamplerTesterMixin(unittest.TestCase): ...@@ -218,8 +200,8 @@ class SamplerTesterMixin(unittest.TestCase):
assert image.shape == (1, 3, 256, 256) assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
import ipdb; ipdb.set_trace() expected_slice = torch.tensor([-0.0304, -0.1895, -0.2436, -0.9837, -0.5422, 0.1931, -0.8175, 0.0862, -0.7783])
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3 assert (image_slice.flatten() - expected_slice).abs().sum() < 1e-3
class PipelineTesterMixin(unittest.TestCase): class PipelineTesterMixin(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment