Unverified Commit 59900147 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[WIP]Vae preprocessor refactor (PR1) (#3557)



VaeImageProcessor.preprocess refactor

* refactored VaeImageProcessor 
   -  allow passing optional height and width argument to resize()
   - add convert_to_rgb
* refactored prepare_latents method for img2img pipelines so that if we pass latents directly as image input, it will not encode it again
* added a test in test_pipelines_common.py to test latents as image inputs
* refactored img2img pipelines that accept latents as image: 
   - controlnet img2img, stable diffusion img2img , instruct_pix2pix

---------
Co-authored-by: default avatar yiyixuxu <yixu310@gmail.com>
Co-authored-by: default avatar Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatar Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatar Sayak Paul <spsayakpaul@gmail.com>
parent 1a6a647e
...@@ -93,6 +93,7 @@ class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTester ...@@ -93,6 +93,7 @@ class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTester
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -47,6 +47,7 @@ class StableDiffusionImageVariationPipelineFastTests( ...@@ -47,6 +47,7 @@ class StableDiffusionImageVariationPipelineFastTests(
batch_params = IMAGE_VARIATION_BATCH_PARAMS batch_params = IMAGE_VARIATION_BATCH_PARAMS
image_params = frozenset([]) image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -32,7 +32,6 @@ from diffusers import ( ...@@ -32,7 +32,6 @@ from diffusers import (
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
enable_full_determinism, enable_full_determinism,
...@@ -91,6 +90,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin ...@@ -91,6 +90,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
...@@ -142,6 +142,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin ...@@ -142,6 +142,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
def get_dummy_inputs(self, device, seed=0): def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
if str(device).startswith("mps"): if str(device).startswith("mps"):
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
else: else:
...@@ -160,12 +161,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin ...@@ -160,12 +161,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components() components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components) sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
sd_pipe = sd_pipe.to(device) sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None) sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
inputs["image"] = inputs["image"] / 2 + 0.5
image = sd_pipe(**inputs).images image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -178,12 +177,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin ...@@ -178,12 +177,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components() components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components) sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
sd_pipe = sd_pipe.to(device) sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None) sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
inputs["image"] = inputs["image"] / 2 + 0.5
negative_prompt = "french fries" negative_prompt = "french fries"
output = sd_pipe(**inputs, negative_prompt=negative_prompt) output = sd_pipe(**inputs, negative_prompt=negative_prompt)
image = output.images image = output.images
...@@ -198,14 +195,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin ...@@ -198,14 +195,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components() components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components) sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
sd_pipe = sd_pipe.to(device) sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None) sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * 2 inputs["prompt"] = [inputs["prompt"]] * 2
inputs["image"] = inputs["image"].repeat(2, 1, 1, 1) inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
inputs["image"] = inputs["image"] / 2 + 0.5
image = sd_pipe(**inputs).images image = sd_pipe(**inputs).images
image_slice = image[-1, -3:, -3:, -1] image_slice = image[-1, -3:, -3:, -1]
...@@ -221,12 +216,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin ...@@ -221,12 +216,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
) )
sd_pipe = StableDiffusionImg2ImgPipeline(**components) sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
sd_pipe = sd_pipe.to(device) sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None) sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
inputs["image"] = inputs["image"] / 2 + 0.5
image = sd_pipe(**inputs).images image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
......
...@@ -88,6 +88,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipelin ...@@ -88,6 +88,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipelin
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = frozenset([]) image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -31,10 +31,15 @@ from diffusers import ( ...@@ -31,10 +31,15 @@ from diffusers import (
StableDiffusionInstructPix2PixPipeline, StableDiffusionInstructPix2PixPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_image, slow, torch_device from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
...@@ -47,9 +52,8 @@ class StableDiffusionInstructPix2PixPipelineFastTests( ...@@ -47,9 +52,8 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
pipeline_class = StableDiffusionInstructPix2PixPipeline pipeline_class = StableDiffusionInstructPix2PixPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"} params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = frozenset( image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
[] image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
...@@ -163,6 +167,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests( ...@@ -163,6 +167,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
image = np.array(inputs["image"]).astype(np.float32) / 255.0 image = np.array(inputs["image"]).astype(np.float32) / 255.0
image = torch.from_numpy(image).unsqueeze(0).to(device) image = torch.from_numpy(image).unsqueeze(0).to(device)
image = image / 2 + 0.5
image = image.permute(0, 3, 1, 2) image = image.permute(0, 3, 1, 2)
inputs["image"] = image.repeat(2, 1, 1, 1) inputs["image"] = image.repeat(2, 1, 1, 1)
...@@ -199,6 +204,28 @@ class StableDiffusionInstructPix2PixPipelineFastTests( ...@@ -199,6 +204,28 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3) super().test_inference_batch_single_identical(expected_max_diff=3e-3)
# Overridden: the shared mixin version does not apply because InstructPix2Pix
# encodes the image with the latent distribution's mode (no sampling, no scaling).
def test_latents_input(self):
    components = self.get_dummy_components()
    sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)
    # Disable resize/normalize so the raw image and the pre-encoded latents
    # receive identical preprocessing.
    sd_pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
    sd_pipe = sd_pipe.to(torch_device)
    sd_pipe.set_progress_bar_config(disable=None)

    pt_inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt")
    image_out = sd_pipe(**pt_inputs)[0]

    vae = components["vae"]
    latent_inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt")
    for param_name in self.image_latents_params:
        if param_name in latent_inputs:
            # pix2pix uses the distribution mode, not a generator-driven sample
            latent_inputs[param_name] = vae.encode(latent_inputs[param_name]).latent_dist.mode()
    latents_out = sd_pipe(**latent_inputs)[0]

    max_diff = np.abs(image_out - latents_out).max()
    self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -44,6 +44,7 @@ class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, Pi ...@@ -44,6 +44,7 @@ class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, Pi
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -45,6 +45,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli ...@@ -45,6 +45,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -32,11 +32,16 @@ from diffusers import ( ...@@ -32,11 +32,16 @@ from diffusers import (
StableDiffusionPix2PixZeroPipeline, StableDiffusionPix2PixZeroPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_numpy, slow, torch_device from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, load_image, load_pt, require_torch_gpu, skip_mps from diffusers.utils.testing_utils import enable_full_determinism, load_image, load_pt, require_torch_gpu, skip_mps
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..pipeline_params import (
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference
enable_full_determinism() enable_full_determinism()
...@@ -45,11 +50,10 @@ enable_full_determinism() ...@@ -45,11 +50,10 @@ enable_full_determinism()
@skip_mps @skip_mps
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionPix2PixZeroPipeline pipeline_class = StableDiffusionPix2PixZeroPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"image"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = frozenset( image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
[] image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -130,6 +134,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip ...@@ -130,6 +134,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
def get_dummy_inversion_inputs(self, device, seed=0): def get_dummy_inversion_inputs(self, device, seed=0):
dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device) dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
dummy_image = dummy_image / 2 + 0.5
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
inputs = { inputs = {
...@@ -145,6 +150,24 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip ...@@ -145,6 +150,24 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
} }
return inputs return inputs
def get_dummy_inversion_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
    """Return dummy inversion inputs with the image converted to the requested input type."""
    inputs = self.get_dummy_inversion_inputs(device, seed)
    image = inputs["image"]

    if input_image_type == "pt":
        pass  # already a torch tensor
    elif input_image_type == "np":
        image = VaeImageProcessor.pt_to_numpy(image)
    elif input_image_type == "pil":
        # go through numpy on the way to PIL
        image = VaeImageProcessor.numpy_to_pil(VaeImageProcessor.pt_to_numpy(image))
    else:
        raise ValueError(f"unsupported input_image_type {input_image_type}")

    inputs["image"] = image
    inputs["output_type"] = output_type
    return inputs
def test_save_load_optional_components(self): def test_save_load_optional_components(self):
if not hasattr(self.pipeline_class, "_optional_components"): if not hasattr(self.pipeline_class, "_optional_components"):
return return
...@@ -281,6 +304,41 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip ...@@ -281,6 +304,41 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_outputs_equivalent(self):
    """Inversion must produce equivalent images for `pt`, `np`, and `pil` output types."""
    device = torch_device
    sd_pipe = StableDiffusionPix2PixZeroPipeline(**self.get_dummy_components())
    sd_pipe = sd_pipe.to(device)
    sd_pipe.set_progress_bar_config(disable=None)

    outputs = {
        out_type: sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type=out_type)).images
        for out_type in ("pt", "np", "pil")
    }

    # pt output is NCHW on device; bring it to NHWC numpy before comparing
    pt_as_np = outputs["pt"].cpu().numpy().transpose(0, 2, 3, 1)
    max_diff = np.abs(pt_as_np - outputs["np"]).max()
    self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generate different results from `output_type=='np'`")

    # PIL round-trips through uint8, so allow up to 2 levels of quantization error
    max_diff = np.abs(np.array(outputs["pil"][0]) - (outputs["np"][0] * 255).round()).max()
    self.assertLess(max_diff, 2.0, "`output_type=='pil'` generate different results from `output_type=='np'`")
def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_inputs_equivalent(self):
    """Inversion must produce equivalent results for `pt`, `np`, and `pil` image inputs."""
    device = torch_device
    sd_pipe = StableDiffusionPix2PixZeroPipeline(**self.get_dummy_components())
    sd_pipe = sd_pipe.to(device)
    sd_pipe.set_progress_bar_config(disable=None)

    def invert_with(image_type):
        # run inversion with the dummy image converted to the given input type
        return sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type=image_type)).images

    out_input_pt = invert_with("pt")
    out_input_np = invert_with("np")
    out_input_pil = invert_with("pil")

    max_diff = np.abs(out_input_pt - out_input_np).max()
    self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`")
    # PIL input goes through uint8, so only require mean pixel closeness
    assert_mean_pixel_difference(out_input_pil, out_input_np, expected_max_diff=1)
# Non-determinism caused by the scheduler optimizing the latent inputs during inference # Non-determinism caused by the scheduler optimizing the latent inputs during inference
@unittest.skip("non-deterministic pipeline") @unittest.skip("non-deterministic pipeline")
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
......
...@@ -41,6 +41,7 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes ...@@ -41,6 +41,7 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_cpu_offload = False test_cpu_offload = False
def get_dummy_components(self): def get_dummy_components(self):
......
...@@ -47,6 +47,7 @@ class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTeste ...@@ -47,6 +47,7 @@ class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTeste
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -45,6 +45,7 @@ class StableDiffusionAttendAndExcitePipelineFastTests( ...@@ -45,6 +45,7 @@ class StableDiffusionAttendAndExcitePipelineFastTests(
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"}) batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
# Attend and excite requires being able to run a backward pass at # Attend and excite requires being able to run a backward pass at
# inference time. There's no deterministic backward operator for pad # inference time. There's no deterministic backward operator for pad
......
...@@ -51,7 +51,12 @@ from diffusers.utils import ( ...@@ -51,7 +51,12 @@ from diffusers.utils import (
) )
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
...@@ -65,9 +70,8 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -65,9 +70,8 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = frozenset( image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
[] image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -49,6 +49,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli ...@@ -49,6 +49,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli
image_params = frozenset( image_params = frozenset(
[] []
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -40,6 +40,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineLatentTesterMixin, Pipeli ...@@ -40,6 +40,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineLatentTesterMixin, Pipeli
image_params = frozenset( image_params = frozenset(
[] []
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -52,6 +52,7 @@ class StableDiffusionLatentUpscalePipelineFastTests(PipelineLatentTesterMixin, P ...@@ -52,6 +52,7 @@ class StableDiffusionLatentUpscalePipelineFastTests(PipelineLatentTesterMixin, P
image_params = frozenset( image_params = frozenset(
[] []
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
test_cpu_offload = True test_cpu_offload = True
......
...@@ -27,6 +27,7 @@ class StableUnCLIPPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMix ...@@ -27,6 +27,7 @@ class StableUnCLIPPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMix
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
# TODO(will) Expected attn_bias.stride(1) == 0 to be true, but got false # TODO(will) Expected attn_bias.stride(1) == 0 to be true, but got false
test_xformers_attention = False test_xformers_attention = False
......
...@@ -46,6 +46,7 @@ class StableUnCLIPImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTe ...@@ -46,6 +46,7 @@ class StableUnCLIPImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTe
image_params = frozenset( image_params = frozenset(
[] []
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self): def get_dummy_components(self):
embedder_hidden_size = 32 embedder_hidden_size = 32
......
...@@ -8,6 +8,7 @@ import unittest ...@@ -8,6 +8,7 @@ import unittest
from typing import Callable, Union from typing import Callable, Union
import numpy as np import numpy as np
import PIL
import torch import torch
import diffusers import diffusers
...@@ -39,9 +40,28 @@ class PipelineLatentTesterMixin: ...@@ -39,9 +40,28 @@ class PipelineLatentTesterMixin:
"`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results" "`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results"
) )
@property
def image_latents_params(self) -> frozenset:
    # Abstract-ish hook: child test classes must override this with the set of
    # image-input names that may also accept pre-encoded latents.
    msg = (
        "You need to set the attribute `image_latents_params` in the child test class. "
        "`image_latents_params` are tested for if passing latents directly are producing same results"
    )
    raise NotImplementedError(msg)
def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"): def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
inputs = self.get_dummy_inputs(device, seed) inputs = self.get_dummy_inputs(device, seed)
def convert_to_pt(image):
    """Convert `image` (torch tensor, numpy array, or PIL image) to a torch tensor."""
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, np.ndarray):
        return VaeImageProcessor.numpy_to_pt(image)
    if isinstance(image, PIL.Image.Image):
        # PIL -> numpy -> torch
        return VaeImageProcessor.numpy_to_pt(VaeImageProcessor.pil_to_numpy(image))
    raise ValueError(f"unsupported input_image_type {type(image)}")
def convert_pt_to_type(image, input_image_type): def convert_pt_to_type(image, input_image_type):
if input_image_type == "pt": if input_image_type == "pt":
input_image = image input_image = image
...@@ -56,21 +76,32 @@ class PipelineLatentTesterMixin: ...@@ -56,21 +76,32 @@ class PipelineLatentTesterMixin:
for image_param in self.image_params: for image_param in self.image_params:
if image_param in inputs.keys(): if image_param in inputs.keys():
inputs[image_param] = convert_pt_to_type(inputs[image_param], input_image_type) inputs[image_param] = convert_pt_to_type(
convert_to_pt(inputs[image_param]).to(device), input_image_type
)
inputs["output_type"] = output_type inputs["output_type"] = output_type
return inputs return inputs
def test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4): def test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4):
self._test_pt_np_pil_outputs_equivalent(expected_max_diff=expected_max_diff)
def _test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4, input_image_type="pt"):
components = self.get_dummy_components() components = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
output_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pt"))[0] output_pt = pipe(
output_np = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="np"))[0] **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pt")
output_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pil"))[0] )[0]
output_np = pipe(
**self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="np")
)[0]
output_pil = pipe(
**self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pil")
)[0]
max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
self.assertLess( self.assertLess(
...@@ -98,6 +129,31 @@ class PipelineLatentTesterMixin: ...@@ -98,6 +129,31 @@ class PipelineLatentTesterMixin:
max_diff = np.abs(out_input_pil - out_input_np).max() max_diff = np.abs(out_input_pil - out_input_np).max()
self.assertLess(max_diff, 1e-2, "`input_type=='pt'` generate different result from `input_type=='np'`") self.assertLess(max_diff, 1e-2, "`input_type=='pt'` generate different result from `input_type=='np'`")
def test_latents_input(self):
    """Passing VAE-encoded latents as the image input must match passing the raw image."""
    if not self.image_latents_params:
        # pipeline does not accept latents in place of images
        return

    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    # Disable resize/normalize so latents and images see identical preprocessing.
    pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    image_out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]

    vae = components["vae"]
    latent_inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt")
    generator = latent_inputs["generator"]
    for param_name in self.image_latents_params:
        if param_name in latent_inputs:
            # sample from the posterior with the pipeline's own generator and
            # apply the VAE scaling factor, mirroring pipeline prepare_latents
            encoded = vae.encode(latent_inputs[param_name]).latent_dist.sample(generator)
            latent_inputs[param_name] = encoded * vae.config.scaling_factor

    latents_out = pipe(**latent_inputs)[0]
    max_diff = np.abs(image_out - latents_out).max()
    self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
@require_torch @require_torch
class PipelineTesterMixin: class PipelineTesterMixin:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment