Unverified Commit 8be48507 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

fix test_components (#928)

parent 4bf675f4
...@@ -1405,7 +1405,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -1405,7 +1405,7 @@ class PipelineFastTests(unittest.TestCase):
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# make sure here that pndm scheduler skips prk # make sure here that pndm scheduler skips prk
inpaint = StableDiffusionInpaintPipeline( inpaint = StableDiffusionInpaintPipelineLegacy(
unet=unet, unet=unet,
scheduler=scheduler, scheduler=scheduler,
vae=vae, vae=vae,
...@@ -1413,9 +1413,9 @@ class PipelineFastTests(unittest.TestCase): ...@@ -1413,9 +1413,9 @@ class PipelineFastTests(unittest.TestCase):
tokenizer=tokenizer, tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker, safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor, feature_extractor=self.dummy_extractor,
) ).to(torch_device)
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components) img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
text2img = StableDiffusionPipeline(**inpaint.components) text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment