Unverified Commit a7f25b4a authored by YiYi Xu, committed by GitHub

Postprocessing refactor img2img (#3268)



* refactor img2img VaeImageProcessor.postprocess

* remove `# Copied from` for __init__, run_safety_checker, decode_latents

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------
Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 0e82fb19
@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings
-from typing import Union
+from typing import List, Optional, Union
 
 import numpy as np
 import PIL
@@ -21,7 +21,7 @@ import torch
 from PIL import Image
 
 from .configuration_utils import ConfigMixin, register_to_config
-from .utils import CONFIG_NAME, PIL_INTERPOLATION
+from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
 
 
 class VaeImageProcessor(ConfigMixin):
@@ -82,7 +82,7 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def pt_to_numpy(images):
         """
-        Convert a numpy image to a pytorch tensor
+        Convert a pytorch tensor to a numpy image
         """
         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
         return images
@@ -94,6 +94,13 @@ class VaeImageProcessor(ConfigMixin):
         """
         return 2.0 * images - 1.0
 
+    @staticmethod
+    def denormalize(images):
+        """
+        Denormalize an image array to [0,1]
+        """
+        return (images / 2 + 0.5).clamp(0, 1)
+
     def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
         """
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
@@ -165,17 +172,39 @@ class VaeImageProcessor(ConfigMixin):
     def postprocess(
         self,
-        image,
+        image: torch.FloatTensor,
         output_type: str = "pil",
+        do_denormalize: Optional[List[bool]] = None,
     ):
-        if isinstance(image, torch.Tensor) and output_type == "pt":
+        if not isinstance(image, torch.Tensor):
+            raise ValueError(
+                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+            )
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            return image
+
+        if do_denormalize is None:
+            do_denormalize = [self.config.do_normalize] * image.shape[0]
+
+        image = torch.stack(
+            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
+        )
+
+        if output_type == "pt":
             return image
 
         image = self.pt_to_numpy(image)
 
         if output_type == "np":
             return image
-        elif output_type == "pil":
+
+        if output_type == "pil":
             return self.numpy_to_pil(image)
-        else:
-            raise ValueError(f"Unsupported output_type {output_type}.")
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
@@ -202,6 +203,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -212,11 +214,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.register_to_config(
-            requires_safety_checker=requires_safety_checker,
-        )
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
@@ -436,17 +435,32 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
-        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
-        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-        )
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
        return image, has_nsfw_concept
 
     def decode_latents(self, latents):
+        warnings.warn(
+            (
+                "The decode_latents method is deprecated and will be removed in a future version. Please"
+                " use VaeImageProcessor instead"
+            ),
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -730,27 +744,19 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)
 
-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-            output_type = "np"
-
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
             image = latents
             has_nsfw_concept = None
-        else:
-            image = self.decode_latents(latents)
-            if self.safety_checker is not None:
-                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-            else:
-                has_nsfw_concept = False
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
...
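The per-image `do_denormalize` list exists because the safety checker replaces flagged outputs with black (all-zero) images; pushing those through `denormalize` (`x / 2 + 0.5`) would wash them out to mid-gray. A quick illustration of the pitfall (not from the diff):

import torch

# A safety-checker-blanked image is all zeros (black). Denormalizing it
# as if it were still in [-1, 1] turns it uniformly gray.
flagged = torch.zeros(1, 3, 64, 64)
print((flagged / 2 + 0.5).clamp(0, 1).mean())  # tensor(0.5000) -- gray, not black

Skipping denormalization for flagged images (`do_denormalize[i] = False`) keeps them black, which is why the pipeline derives the list from `has_nsfw_concept` rather than applying one setting to the whole batch.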
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
@@ -205,6 +206,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -215,11 +217,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.register_to_config(
-            requires_safety_checker=requires_safety_checker,
-        )
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -443,17 +442,30 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
-        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
-        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-        )
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
        return image, has_nsfw_concept
 
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -738,27 +750,19 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)
 
-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-            output_type = "np"
-
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
             image = latents
             has_nsfw_concept = None
-        else:
-            image = self.decode_latents(latents)
-            if self.safety_checker is not None:
-                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-            else:
-                has_nsfw_concept = False
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
...
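Both pipelines keep `decode_latents` only as a deprecated shim that warns and returns a numpy array in [0, 1]. Callers can migrate to the processor path with something like the sketch below (a hedged equivalent, not from the diff; `decode_and_postprocess` is an illustrative helper name):

import torch

def decode_and_postprocess(pipe, latents, output_type="np"):
    """Rough replacement for the deprecated pipe.decode_latents(latents)."""
    with torch.no_grad():
        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
    # The processor's default do_denormalize rescales [-1, 1] -> [0, 1],
    # matching the old (x / 2 + 0.5).clamp(0, 1) behavior.
    return pipe.image_processor.postprocess(image, output_type=output_type)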
@@ -42,7 +42,7 @@ class ImageProcessorTest(unittest.TestCase):
         return image
 
     def test_vae_image_processor_pt(self):
-        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
 
         input_pt = self.dummy_sample
         input_np = self.to_np(input_pt)
@@ -59,7 +59,7 @@ class ImageProcessorTest(unittest.TestCase):
             ), f"decoded output does not match input for output_type {output_type}"
 
     def test_vae_image_processor_np(self):
-        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
         input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
 
         for output_type in ["pt", "np", "pil"]:
@@ -72,7 +72,7 @@ class ImageProcessorTest(unittest.TestCase):
             ), f"decoded output does not match input for output_type {output_type}"
 
     def test_vae_image_processor_pil(self):
-        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
         input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
         input_pil = image_processor.numpy_to_pil(input_np)
...
@@ -22,6 +22,10 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
 
 TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
 
+TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
+
+IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
+
 IMAGE_VARIATION_PARAMS = frozenset(
     [
         "image",
...
@@ -35,18 +35,23 @@ from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
 
-from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import (
+    IMAGE_TO_IMAGE_IMAGE_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+)
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
 
     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -96,33 +101,19 @@ class StableDiffusionImg2ImgPipelineFastTests
         }
         return components
 
-    def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"):
+    def get_dummy_inputs(self, device, seed=0):
         image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
-
-        if input_image_type == "pt":
-            input_image = image
-        elif input_image_type == "np":
-            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
-        elif input_image_type == "pil":
-            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
-            input_image = VaeImageProcessor.numpy_to_pil(input_image)
-        else:
-            raise ValueError(f"unsupported input_image_type {input_image_type}.")
-
-        if output_type not in ["pt", "np", "pil"]:
-            raise ValueError(f"unsupported output_type {output_type}")
-
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "image": input_image,
+            "image": image,
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": output_type,
+            "output_type": "numpy",
         }
         return inputs
 
@@ -130,11 +121,12 @@ class StableDiffusionImg2ImgPipelineFastTests
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         image = sd_pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
@@ -147,11 +139,12 @@ class StableDiffusionImg2ImgPipelineFastTests
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         negative_prompt = "french fries"
         output = sd_pipe(**inputs, negative_prompt=negative_prompt)
         image = output.images
@@ -166,13 +159,14 @@ class StableDiffusionImg2ImgPipelineFastTests
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
         inputs["prompt"] = [inputs["prompt"]] * 2
         inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         image = sd_pipe(**inputs).images
         image_slice = image[-1, -3:, -3:, -1]
@@ -188,11 +182,12 @@ class StableDiffusionImg2ImgPipelineFastTests
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         image = sd_pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
@@ -217,36 +212,6 @@ class StableDiffusionImg2ImgPipelineFastTests
     def test_attention_slicing_forward_pass(self):
         return super().test_attention_slicing_forward_pass()
 
-    @skip_mps
-    def test_pt_np_pil_outputs_equivalent(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0]
-        output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0]
-        output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0]
-
-        assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4
-        assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4
-
-    @skip_mps
-    def test_image_types_consistent(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0]
-        output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0]
-        output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0]
-
-        assert np.abs(output_pt - output_np).max() <= 1e-4
-        assert np.abs(output_pil - output_np).max() <= 1e-2
-
 
 @slow
 @require_torch_gpu
...
@@ -12,6 +12,7 @@ import torch
 
 import diffusers
 from diffusers import DiffusionPipeline
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
 from diffusers.utils.testing_utils import require_torch, torch_device
@@ -27,6 +28,78 @@ def to_np(tensor):
     return tensor
 
 
+class PipelineLatentTesterMixin:
+    """
+    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
+    It provides a set of common tests for PyTorch pipeline that has vae, e.g.
+    equivalence of different input and output types, etc.
+    """
+
+    @property
+    def image_params(self) -> frozenset:
+        raise NotImplementedError(
+            "You need to set the attribute `image_params` in the child test class. "
+            "`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results"
+        )
+
+    def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
+        inputs = self.get_dummy_inputs(device, seed)
+
+        def convert_pt_to_type(image, input_image_type):
+            if input_image_type == "pt":
+                input_image = image
+            elif input_image_type == "np":
+                input_image = VaeImageProcessor.pt_to_numpy(image)
+            elif input_image_type == "pil":
+                input_image = VaeImageProcessor.pt_to_numpy(image)
+                input_image = VaeImageProcessor.numpy_to_pil(input_image)
+            else:
+                raise ValueError(f"unsupported input_image_type {input_image_type}.")
+            return input_image
+
+        for image_param in self.image_params:
+            if image_param in inputs.keys():
+                inputs[image_param] = convert_pt_to_type(inputs[image_param], input_image_type)
+
+        inputs["output_type"] = output_type
+
+        return inputs
+
+    def test_pt_np_pil_outputs_equivalent(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        output_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pt"))[0]
+        output_np = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="np"))[0]
+        output_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pil"))[0]
+
+        max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
+        self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generate different results from `output_type=='np'`")
+
+        max_diff = np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max()
+        self.assertLess(max_diff, 1e-4, "`output_type=='pil'` generate different results from `output_type=='np'`")
+
+    def test_pt_np_pil_inputs_equivalent(self):
+        if len(self.image_params) == 0:
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        out_input_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
+        out_input_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+        out_input_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pil"))[0]
+
+        max_diff = np.abs(out_input_pt - out_input_np).max()
+        self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`")
+        max_diff = np.abs(out_input_pil - out_input_np).max()
+        self.assertLess(max_diff, 1e-2, "`input_type=='pt'` generate different result from `input_type=='np'`")
+
+
 @require_torch
 class PipelineTesterMixin:
     """
@@ -339,9 +412,6 @@ class PipelineTesterMixin:
 
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_float16_inference(self):
-        self._test_float16_inference()
-
-    def _test_float16_inference(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.to(torch_device)
@@ -355,13 +425,10 @@ class PipelineTesterMixin:
         output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0]
 
         max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
-        self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+        self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.")
 
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self):
-        self._test_save_load_float16()
-
-    def _test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         for name, module in components.items():
             if hasattr(module, "half"):
@@ -390,9 +457,7 @@ class PipelineTesterMixin:
         output_loaded = pipe_loaded(**inputs)[0]
 
         max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
-        self.assertLess(
-            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
-        )
+        self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.")
 
     def test_save_load_optional_components(self):
         if not hasattr(self.pipeline_class, "_optional_components"):
...
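Other pipeline test suites can opt into these equivalence checks the same way the img2img tests above do: inherit `PipelineLatentTesterMixin` and declare which call arguments are images via `image_params`. A minimal hypothetical skeleton (the class name is illustrative, not from the diff):

import unittest

from ..pipeline_params import IMAGE_TO_IMAGE_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


class MyImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    pipeline_class = ...  # the pipeline under test (placeholder)
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS  # call args that hold images

    def get_dummy_inputs(self, device, seed=0):
        # Must return image inputs as pt tensors; get_dummy_inputs_by_type()
        # converts them to np/pil before each input-equivalence test runs.
        raise NotImplementedError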