Commit 43bbc781 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

adapt test

parent 1c14ce95
...@@ -36,6 +36,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -36,6 +36,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm.tqdm(self.scheduler.timesteps): for t in tqdm.tqdm(self.scheduler.timesteps):
with torch.no_grad():
model_output = self.unet(image, t) model_output = self.unet(image, t)
if isinstance(model_output, dict): if isinstance(model_output, dict):
...@@ -46,5 +47,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -46,5 +47,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"] image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
# decode image with vae # decode image with vae
with torch.no_grad():
image = self.vqvae.decode(image) image = self.vqvae.decode(image)
return {"sample": image} return {"sample": image}
...@@ -1070,7 +1070,8 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1070,7 +1070,8 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_ldm_uncond(self): def test_ldm_uncond(self):
ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True) # ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True)
ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/latent-diffusion-celeba-256")
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5)["sample"] image = ldm(generator=generator, num_inference_steps=5)["sample"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment