Unverified Commit c6d0dff4 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Fix ldm tests on master by not running the CPU tests on GPU (#1729)

parent a40095dd
...@@ -72,6 +72,9 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase): ...@@ -72,6 +72,9 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
return CLIPTextModel(config) return CLIPTextModel(config)
def test_inference_text2img(self): def test_inference_text2img(self):
if torch_device != "cpu":
return
unet = self.dummy_cond_unet unet = self.dummy_cond_unet
scheduler = DDIMScheduler() scheduler = DDIMScheduler()
vae = self.dummy_vae vae = self.dummy_vae
...@@ -91,12 +94,16 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase): ...@@ -91,12 +94,16 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy" [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
).images ).images
generator = torch.manual_seed(0) device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image = ldm( image = ldm(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy" [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
).images ).images
generator = torch.manual_seed(0) device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = ldm( image_from_tuple = ldm(
[prompt], [prompt],
generator=generator, generator=generator,
...@@ -124,7 +131,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase): ...@@ -124,7 +131,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
ldm.set_progress_bar_config(disable=None) ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image = ldm( image = ldm(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy" [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
).images ).images
...@@ -141,7 +151,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase): ...@@ -141,7 +151,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
ldm.set_progress_bar_config(disable=None) ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment