Commit c5c93996 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

correct paths for tests

parent 836f3f35
...@@ -365,9 +365,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -365,9 +365,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
return init_dict, inputs_dict return init_dict, inputs_dict
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained( model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
"/home/patrick/google_checkpoints/unet-ldm-dummy-update", output_loading_info=True
)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0) self.assertEqual(len(loading_info["missing_keys"]), 0)
...@@ -378,7 +376,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -378,7 +376,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
def test_output_pretrained(self): def test_output_pretrained(self):
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/unet-ldm-dummy-update") model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
model.eval() model.eval()
torch.manual_seed(0) torch.manual_seed(0)
...@@ -472,9 +470,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -472,9 +470,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
return init_dict, inputs_dict return init_dict, inputs_dict
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained( model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
"/home/patrick/google_checkpoints/ncsnpp-celebahq-256", output_loading_info=True
)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0) self.assertEqual(len(loading_info["missing_keys"]), 0)
...@@ -487,7 +483,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -487,7 +483,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
def test_output_pretrained_ve_mid(self): def test_output_pretrained_ve_mid(self):
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-celebahq-256") model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
model.to(torch_device) model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
...@@ -512,7 +508,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -512,7 +508,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self): def test_output_pretrained_ve_large(self):
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-ffhq-ve-dummy-update") model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
model.to(torch_device) model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
...@@ -582,9 +578,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -582,9 +578,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
pass pass
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = VQModel.from_pretrained( model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
"/home/patrick/google_checkpoints/vqgan-dummy", output_loading_info=True
)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0) self.assertEqual(len(loading_info["missing_keys"]), 0)
...@@ -594,7 +588,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -594,7 +588,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
def test_output_pretrained(self): def test_output_pretrained(self):
model = VQModel.from_pretrained("/home/patrick/google_checkpoints/vqgan-dummy") model = VQModel.from_pretrained("fusing/vqgan-dummy")
model.eval() model.eval()
torch.manual_seed(0) torch.manual_seed(0)
...@@ -655,9 +649,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): ...@@ -655,9 +649,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
pass pass
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained( model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
"/home/patrick/google_checkpoints/autoencoder-kl-dummy", output_loading_info=True
)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0) self.assertEqual(len(loading_info["missing_keys"]), 0)
...@@ -667,7 +659,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): ...@@ -667,7 +659,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
def test_output_pretrained(self): def test_output_pretrained(self):
model = AutoencoderKL.from_pretrained("/home/patrick/google_checkpoints/autoencoder-kl-dummy") model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
model.eval() model.eval()
torch.manual_seed(0) torch.manual_seed(0)
...@@ -715,7 +707,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -715,7 +707,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32" model_path = "google/ddpm-cifar10-32"
ddpm = DDPMPipeline.from_pretrained(model_path) ddpm = DDPMPipeline.from_pretrained(model_path)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
...@@ -733,7 +725,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -733,7 +725,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_output_format(self): def test_output_format(self):
model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32" model_path = "google/ddpm-cifar10-32"
pipe = DDIMPipeline.from_pretrained(model_path) pipe = DDIMPipeline.from_pretrained(model_path)
...@@ -754,7 +746,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -754,7 +746,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_ddpm_cifar10(self): def test_ddpm_cifar10(self):
model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id) unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDPMScheduler.from_config(model_id) scheduler = DDPMScheduler.from_config(model_id)
...@@ -773,7 +765,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -773,7 +765,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_ddim_lsun(self): def test_ddim_lsun(self):
model_id = "/home/patrick/google_checkpoints/ddpm-ema-bedroom-256" model_id = "google/ddpm-ema-bedroom-256"
unet = UNet2DModel.from_pretrained(model_id) unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler.from_config(model_id) scheduler = DDIMScheduler.from_config(model_id)
...@@ -791,7 +783,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -791,7 +783,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_ddim_cifar10(self): def test_ddim_cifar10(self):
model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id) unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler(tensor_format="pt") scheduler = DDIMScheduler(tensor_format="pt")
...@@ -809,7 +801,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -809,7 +801,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_pndm_cifar10(self): def test_pndm_cifar10(self):
model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id) unet = UNet2DModel.from_pretrained(model_id)
scheduler = PNDMScheduler(tensor_format="pt") scheduler = PNDMScheduler(tensor_format="pt")
...@@ -826,7 +818,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -826,7 +818,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_ldm_text2img(self): def test_ldm_text2img(self):
ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256") ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
...@@ -842,7 +834,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -842,7 +834,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_ldm_text2img_fast(self): def test_ldm_text2img_fast(self):
ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256") ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
...@@ -856,13 +848,13 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -856,13 +848,13 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_score_sde_ve_pipeline(self): def test_score_sde_ve_pipeline(self):
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-church-256") model = UNet2DModel.from_pretrained("google/ncsnpp-church-256")
torch.manual_seed(0) torch.manual_seed(0)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
scheduler = ScoreSdeVeScheduler.from_config("/home/patrick/google_checkpoints/ncsnpp-church-256") scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
...@@ -877,7 +869,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -877,7 +869,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_ldm_uncond(self): def test_ldm_uncond(self):
ldm = LDMPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256") ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"] image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment