Unverified Commit 5c7a35a2 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Torch 2.0 compile] Fix more torch compile breaks (#3313)



* Fix more torch compile breaks

* add tests

* Fix all

* fix controlnet

* fix more

* Add Horace He as co-author.
>
>
Co-authored-by: default avatarHorace He <horacehe2007@yahoo.com>

* Add Horace He as co-author.
Co-authored-by: default avatarHorace He <horacehe2007@yahoo.com>

---------
Co-authored-by: default avatarHorace He <horacehe2007@yahoo.com>
parent a7f25b4a
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from packaging import version
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import ( from diffusers import (
...@@ -460,6 +461,28 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase): ...@@ -460,6 +461,28 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}" assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros
def test_img2img_compile(self):
if version.parse(torch.__version__) < version.parse("2.0"):
print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
return
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 768, 3)
expected_slice = np.array([0.0593, 0.0607, 0.0851, 0.0582, 0.0636, 0.0721, 0.0751, 0.0981, 0.0781])
assert np.abs(expected_slice - image_slice).max() < 1e-3
@nightly @nightly
@require_torch_gpu @require_torch_gpu
......
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from packaging import version
from PIL import Image from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
...@@ -274,6 +275,31 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase): ...@@ -274,6 +275,31 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
# make sure that less than 2.2 GB is allocated # make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9 assert mem_bytes < 2.2 * 10**9
def test_inpaint_compile(self):
if version.parse(torch.__version__) < version.parse("2.0"):
print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
return
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", safety_checker=None
)
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])
assert np.abs(expected_slice - image_slice).max() < 1e-4
assert np.abs(expected_slice - image_slice).max() < 1e-3
@nightly @nightly
@require_torch_gpu @require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment