Unverified Commit 1586186e authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

pix2pix tests no write to fs (#2497)

* attend and excite batch test causing timeouts

* pix2pix tests, no write to fs
parent 42beaf1d
...@@ -209,6 +209,13 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) - ...@@ -209,6 +209,13 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
return arry return arry
def load_pt(url: str):
    """Download a serialized PyTorch object from ``url`` and deserialize it.

    Args:
        url: Direct HTTP(S) link to a ``.pt`` file.

    Returns:
        The object produced by ``torch.load`` (typically a tensor).

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    response = requests.get(url)
    response.raise_for_status()
    # NOTE(review): torch.load unpickles arbitrary data — only fetch from trusted hosts.
    arry = torch.load(BytesIO(response.content))
    return arry
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
""" """
Args: Args:
......
...@@ -17,9 +17,7 @@ import gc ...@@ -17,9 +17,7 @@ import gc
import unittest import unittest
import numpy as np import numpy as np
import requests
import torch import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import ( from diffusers import (
...@@ -33,7 +31,7 @@ from diffusers import ( ...@@ -33,7 +31,7 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils import load_numpy, slow, torch_device from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
from ...test_pipelines_common import PipelineTesterMixin from ...test_pipelines_common import PipelineTesterMixin
...@@ -41,16 +39,20 @@ from ...test_pipelines_common import PipelineTesterMixin ...@@ -41,16 +39,20 @@ from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
def download_from_url(embedding_url, local_filepath):
    """Download ``embedding_url`` and write the raw bytes to ``local_filepath``.

    Args:
        embedding_url: Direct HTTP(S) link to the file to fetch.
        local_filepath: Destination path; overwritten if it exists.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    r = requests.get(embedding_url)
    # Without this, a 404/500 would silently write an HTML error page to disk.
    r.raise_for_status()
    with open(local_filepath, "wb") as f:
        f.write(r.content)
@skip_mps @skip_mps
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionPix2PixZeroPipeline pipeline_class = StableDiffusionPix2PixZeroPipeline
@classmethod
def setUpClass(cls):
    """Fetch the source/target embedding tensors once, shared by all tests."""
    repo = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix"
    cls.source_embeds = load_pt(f"{repo}/src_emb_0.pt")
    cls.target_embeds = load_pt(f"{repo}/tgt_emb_0.pt")
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
...@@ -103,15 +105,6 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest. ...@@ -103,15 +105,6 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
return components return components
def get_dummy_inputs(self, device, seed=0): def get_dummy_inputs(self, device, seed=0):
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt"
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt"
for url in [src_emb_url, tgt_emb_url]:
download_from_url(url, url.split("/")[-1])
src_embeds = torch.load(src_emb_url.split("/")[-1])
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
inputs = { inputs = {
...@@ -120,8 +113,8 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest. ...@@ -120,8 +113,8 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
"num_inference_steps": 2, "num_inference_steps": 2,
"guidance_scale": 6.0, "guidance_scale": 6.0,
"cross_attention_guidance_amount": 0.15, "cross_attention_guidance_amount": 0.15,
"source_embeds": src_embeds, "source_embeds": self.source_embeds,
"target_embeds": target_embeds, "target_embeds": self.target_embeds,
"output_type": "numpy", "output_type": "numpy",
} }
return inputs return inputs
...@@ -237,17 +230,18 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase): ...@@ -237,17 +230,18 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_inputs(self, seed=0): @classmethod
generator = torch.manual_seed(seed) def setUpClass(cls):
cls.source_embeds = load_pt(
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat.pt"
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" )
for url in [src_emb_url, tgt_emb_url]: cls.target_embeds = load_pt(
download_from_url(url, url.split("/")[-1]) "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.pt"
)
src_embeds = torch.load(src_emb_url.split("/")[-1]) def get_inputs(self, seed=0):
target_embeds = torch.load(tgt_emb_url.split("/")[-1]) generator = torch.manual_seed(seed)
inputs = { inputs = {
"prompt": "turn him into a cyborg", "prompt": "turn him into a cyborg",
...@@ -255,8 +249,8 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase): ...@@ -255,8 +249,8 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase):
"num_inference_steps": 3, "num_inference_steps": 3,
"guidance_scale": 7.5, "guidance_scale": 7.5,
"cross_attention_guidance_amount": 0.15, "cross_attention_guidance_amount": 0.15,
"source_embeds": src_embeds, "source_embeds": self.source_embeds,
"target_embeds": target_embeds, "target_embeds": self.target_embeds,
"output_type": "numpy", "output_type": "numpy",
} }
return inputs return inputs
...@@ -364,10 +358,17 @@ class InversionPipelineSlowTests(unittest.TestCase): ...@@ -364,10 +358,17 @@ class InversionPipelineSlowTests(unittest.TestCase):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def test_stable_diffusion_pix2pix_inversion(self): @classmethod
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" def setUpClass(cls):
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) raw_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
)
raw_image = raw_image.convert("RGB").resize((512, 512))
cls.raw_image = raw_image
def test_stable_diffusion_pix2pix_inversion(self):
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
) )
...@@ -380,7 +381,7 @@ class InversionPipelineSlowTests(unittest.TestCase): ...@@ -380,7 +381,7 @@ class InversionPipelineSlowTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
output = pipe.invert(caption, image=raw_image, generator=generator, num_inference_steps=10) output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
inv_latents = output[0] inv_latents = output[0]
image_slice = inv_latents[0, -3:, -3:, -1].flatten() image_slice = inv_latents[0, -3:, -3:, -1].flatten()
...@@ -391,9 +392,6 @@ class InversionPipelineSlowTests(unittest.TestCase): ...@@ -391,9 +392,6 @@ class InversionPipelineSlowTests(unittest.TestCase):
assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3 assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3
def test_stable_diffusion_pix2pix_full(self): def test_stable_diffusion_pix2pix_full(self):
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
expected_image = load_numpy( expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy" "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy"
...@@ -411,7 +409,7 @@ class InversionPipelineSlowTests(unittest.TestCase): ...@@ -411,7 +409,7 @@ class InversionPipelineSlowTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
output = pipe.invert(caption, image=raw_image, generator=generator) output = pipe.invert(caption, image=self.raw_image, generator=generator)
inv_latents = output[0] inv_latents = output[0]
source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment