Commit 73354621 authored by patil-suraj's avatar patil-suraj
Browse files

fix tests

parent d8287fcd
...@@ -218,9 +218,9 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -218,9 +218,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.Generator() generator = torch.Generator()
generator = generator.manual_seed(669472945848556) generator = generator.manual_seed(669472945848556)
image = ddpm(generator) image = ddpm(generator=generator)
generator = generator.manual_seed(669472945848556) generator = generator.manual_seed(669472945848556)
new_image = new_ddpm(generator) new_image = new_ddpm(generator=generator)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
...@@ -239,8 +239,8 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -239,8 +239,8 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.Generator(device=torch_device) generator = torch.Generator(device=torch_device)
generator = generator.manual_seed(669472945848556) generator = generator.manual_seed(669472945848556)
image = ddpm(generator) image = ddpm(generator=generator)
generator = generator.manual_seed(669472945848556) generator = generator.manual_seed(669472945848556)
new_image = ddpm_from_hub(generator) new_image = ddpm_from_hub(generator=generator)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment