Commit 577a6a65 authored by anton-l's avatar anton-l
Browse files

Fix SD tests .to(device)

parent 21ceda3f
......@@ -866,7 +866,7 @@ class PipelineTesterMixin(unittest.TestCase):
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1").to(torch_device)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
......@@ -886,7 +886,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_fast_ddim(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1").to(torch_device)
scheduler = DDIMScheduler(
beta_start=0.00085,
......@@ -1003,8 +1003,8 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_lms_stable_diffusion_pipeline(self):
model_id = "CompVis/stable-diffusion-v1-1-diffusers"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(torch_device)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
pipe.scheduler = scheduler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment