Unverified Commit 5c7a35a2 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Torch 2.0 compile] Fix more torch compile breaks (#3313)



* Fix more torch compile breaks

* add tests

* Fix all

* fix controlnet

* fix more

* Add Horace He as co-author.
>
>
Co-authored-by: Horace He <horacehe2007@yahoo.com>

* Add Horace He as co-author.
Co-authored-by: Horace He <horacehe2007@yahoo.com>

---------
Co-authored-by: Horace He <horacehe2007@yahoo.com>
parent a7f25b4a
......@@ -19,6 +19,7 @@ import unittest
import numpy as np
import torch
from packaging import version
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
......@@ -460,6 +461,28 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros
def test_img2img_compile(self):
    """Slow test: run the img2img pipeline with a torch.compile'd UNet and
    check the output slice against precomputed reference values.

    Requires PyTorch >= 2.0 (torch.compile); otherwise the test is skipped.
    """
    if version.parse(torch.__version__) < version.parse("2.0"):
        # Fix: the original message named the wrong test
        # ("test_stable_diffusion_ddim"); also report a real skip instead of
        # printing and silently passing.
        self.skipTest(f"Test `test_img2img_compile` is skipped because {torch.__version__} is < 2.0")

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    # channels_last + fullgraph compile is the configuration this commit
    # is exercising (torch 2.0 compile breaks).
    pipe.unet.to(memory_format=torch.channels_last)
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

    inputs = self.get_inputs(torch_device)
    image = pipe(**inputs).images
    image_slice = image[0, -3:, -3:, -1].flatten()

    assert image.shape == (1, 512, 768, 3)
    expected_slice = np.array([0.0593, 0.0607, 0.0851, 0.0582, 0.0636, 0.0721, 0.0751, 0.0981, 0.0781])
    assert np.abs(expected_slice - image_slice).max() < 1e-3
@nightly
@require_torch_gpu
......
......@@ -19,6 +19,7 @@ import unittest
import numpy as np
import torch
from packaging import version
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
......@@ -274,6 +275,31 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
# make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9
def test_inpaint_compile(self):
    """Slow test: run the inpainting pipeline with a torch.compile'd UNet and
    check the output slice against precomputed reference values.

    Requires PyTorch >= 2.0 (torch.compile); otherwise the test is skipped.
    """
    if version.parse(torch.__version__) < version.parse("2.0"):
        # Fix: the original message named the wrong test
        # ("test_stable_diffusion_ddim"); also report a real skip instead of
        # printing and silently passing.
        self.skipTest(f"Test `test_inpaint_compile` is skipped because {torch.__version__} is < 2.0")

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", safety_checker=None
    )
    pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    # channels_last + fullgraph compile is the configuration this commit
    # is exercising (torch 2.0 compile breaks).
    pipe.unet.to(memory_format=torch.channels_last)
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

    inputs = self.get_inputs(torch_device)
    image = pipe(**inputs).images
    image_slice = image[0, 253:256, 253:256, -1].flatten()

    assert image.shape == (1, 512, 512, 3)
    expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])
    # Fix: the original had two consecutive assertions on the same value
    # (`< 1e-4` then `< 1e-3`); the looser one was dead code. Keep the
    # stricter 1e-4 tolerance.
    assert np.abs(expected_slice - image_slice).max() < 1e-4
@nightly
@require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment