Unverified Commit 59900147 authored by YiYi Xu, committed by GitHub

[WIP] Vae preprocessor refactor (PR1) (#3557)



VaeImageProcessor.preprocess refactor

* refactored VaeImageProcessor:
   - allow passing optional height and width arguments to resize()
   - add convert_to_rgb
* refactored the prepare_latents method for img2img pipelines so that latents passed directly as the image input are not encoded again (see the sketch below)
* added a test in test_pipelines_common.py that passes latents as image inputs
* refactored the img2img pipelines that accept latents as image:
   - controlnet img2img, stable diffusion img2img, instruct_pix2pix

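A minimal sketch of the user-facing effect of the prepare_latents change. The checkpoint name, shapes, and the commented resize() call are illustrative assumptions, not part of this diff:

```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Pixel input: preprocessed and encoded through the VAE as before.
# Latent input: a 4-channel tensor already scaled by vae.config.scaling_factor
# (as in the new test_latents_input) is used as-is; prepare_latents skips the
# second encode.
latents = torch.randn(1, 4, 64, 64)  # e.g. latents kept from an earlier call
image = pipe("a photo of an astronaut", image=latents).images[0]

# resize() now also accepts optional height/width (sketch; exact defaults may differ):
# resized = pipe.image_processor.resize(pil_image, height=512, width=512)
```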
---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 1a6a647e
......@@ -93,6 +93,7 @@ class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTester
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -47,6 +47,7 @@ class StableDiffusionImageVariationPipelineFastTests(
batch_params = IMAGE_VARIATION_BATCH_PARAMS
image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -32,7 +32,6 @@ from diffusers import (
StableDiffusionImg2ImgPipeline,
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import (
enable_full_determinism,
......@@ -91,6 +90,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......@@ -142,6 +142,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
......@@ -160,12 +161,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["image"] = inputs["image"] / 2 + 0.5
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
......@@ -178,12 +177,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["image"] = inputs["image"] / 2 + 0.5
negative_prompt = "french fries"
output = sd_pipe(**inputs, negative_prompt=negative_prompt)
image = output.images
......@@ -198,14 +195,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * 2
inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
inputs["image"] = inputs["image"] / 2 + 0.5
image = sd_pipe(**inputs).images
image_slice = image[-1, -3:, -3:, -1]
......@@ -221,12 +216,10 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["image"] = inputs["image"] / 2 + 0.5
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
......
......@@ -88,6 +88,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipelin
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -31,10 +31,15 @@ from diffusers import (
StableDiffusionInstructPix2PixPipeline,
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -47,9 +52,8 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
pipeline_class = StableDiffusionInstructPix2PixPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = frozenset(
[]
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......@@ -163,6 +167,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
image = np.array(inputs["image"]).astype(np.float32) / 255.0
image = torch.from_numpy(image).unsqueeze(0).to(device)
image = image / 2 + 0.5
image = image.permute(0, 3, 1, 2)
inputs["image"] = image.repeat(2, 1, 1, 1)
......@@ -199,6 +204,28 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
# Overwrite the default test_latents_input because pix2pix encodes the image differently
def test_latents_input(self):
components = self.get_dummy_components()
pipe = StableDiffusionInstructPix2PixPipeline(**components)
pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
vae = components["vae"]
inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt")
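# pix2pix conditions on the unscaled mode of the latent distribution (no scaling_factor),
# unlike the generic test, which samples from the distribution and scales.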
for image_param in self.image_latents_params:
if image_param in inputs.keys():
inputs[image_param] = vae.encode(inputs[image_param]).latent_dist.mode()
out_latents_inputs = pipe(**inputs)[0]
max_diff = np.abs(out - out_latents_inputs).max()
self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
@slow
@require_torch_gpu
......
......@@ -44,6 +44,7 @@ class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, Pi
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -45,6 +45,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -32,11 +32,16 @@ from diffusers import (
StableDiffusionPix2PixZeroPipeline,
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, load_image, load_pt, require_torch_gpu, skip_mps
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference
enable_full_determinism()
......@@ -45,11 +50,10 @@ enable_full_determinism()
@skip_mps
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionPix2PixZeroPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"image"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = frozenset(
[]
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
@classmethod
def setUpClass(cls):
......@@ -130,6 +134,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
def get_dummy_inversion_inputs(self, device, seed=0):
dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
dummy_image = dummy_image / 2 + 0.5
generator = torch.manual_seed(seed)
inputs = {
......@@ -145,6 +150,24 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
}
return inputs
def get_dummy_inversion_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
inputs = self.get_dummy_inversion_inputs(device, seed)
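# Convert the default pt image to the requested input type via VaeImageProcessor's static converters.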
if input_image_type == "pt":
image = inputs["image"]
elif input_image_type == "np":
image = VaeImageProcessor.pt_to_numpy(inputs["image"])
elif input_image_type == "pil":
image = VaeImageProcessor.pt_to_numpy(inputs["image"])
image = VaeImageProcessor.numpy_to_pil(image)
else:
raise ValueError(f"unsupported input_image_type {input_image_type}")
inputs["image"] = image
inputs["output_type"] = output_type
return inputs
def test_save_load_optional_components(self):
if not hasattr(self.pipeline_class, "_optional_components"):
return
......@@ -281,6 +304,41 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_outputs_equivalent(self):
device = torch_device
components = self.get_dummy_components()
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
output_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pt")).images
output_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="np")).images
output_pil = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pil")).images
max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generate different results from `output_type=='np'`")
max_diff = np.abs(np.array(output_pil[0]) - (output_np[0] * 255).round()).max()
self.assertLess(max_diff, 2.0, "`output_type=='pil'` generate different results from `output_type=='np'`")
def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_inputs_equivalent(self):
device = torch_device
components = self.get_dummy_components()
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
out_input_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="pt")).images
out_input_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="np")).images
out_input_pil = sd_pipe.invert(
**self.get_dummy_inversion_inputs_by_type(device, input_image_type="pil")
).images
max_diff = np.abs(out_input_pt - out_input_np).max()
self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`")
assert_mean_pixel_difference(out_input_pil, out_input_np, expected_max_diff=1)
# Non-determinism caused by the scheduler optimizing the latent inputs during inference
@unittest.skip("non-deterministic pipeline")
def test_inference_batch_single_identical(self):
......
......@@ -41,6 +41,7 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_cpu_offload = False
def get_dummy_components(self):
......
......@@ -47,6 +47,7 @@ class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTeste
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -45,6 +45,7 @@ class StableDiffusionAttendAndExcitePipelineFastTests(
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
# Attend and excite requires being able to run a backward pass at
# inference time. There's no deterministic backward operator for pad
......
......@@ -51,7 +51,12 @@ from diffusers.utils import (
)
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -65,9 +70,8 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = frozenset(
[]
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -49,6 +49,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli
image_params = frozenset(
[]
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -40,6 +40,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineLatentTesterMixin, Pipeli
image_params = frozenset(
[]
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -52,6 +52,7 @@ class StableDiffusionLatentUpscalePipelineFastTests(PipelineLatentTesterMixin, P
image_params = frozenset(
[]
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
test_cpu_offload = True
......
......@@ -27,6 +27,7 @@ class StableUnCLIPPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMix
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
# TODO(will) Expected attn_bias.stride(1) == 0 to be true, but got false
test_xformers_attention = False
......
......@@ -46,6 +46,7 @@ class StableUnCLIPImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTe
image_params = frozenset(
[]
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
def get_dummy_components(self):
embedder_hidden_size = 32
......
......@@ -8,6 +8,7 @@ import unittest
from typing import Callable, Union
import numpy as np
import PIL
import torch
import diffusers
......@@ -39,9 +40,28 @@ class PipelineLatentTesterMixin:
"`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results"
)
@property
def image_latents_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `image_latents_params` in the child test class. "
"`image_latents_params` are tested for if passing latents directly are producing same results"
)
def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
inputs = self.get_dummy_inputs(device, seed)
def convert_to_pt(image):
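# Normalize whatever image type the dummy inputs provide (pt, np, or pil) to a torch tensor.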
if isinstance(image, torch.Tensor):
input_image = image
elif isinstance(image, np.ndarray):
input_image = VaeImageProcessor.numpy_to_pt(image)
elif isinstance(image, PIL.Image.Image):
input_image = VaeImageProcessor.pil_to_numpy(image)
input_image = VaeImageProcessor.numpy_to_pt(input_image)
else:
raise ValueError(f"unsupported input_image_type {type(image)}")
return input_image
def convert_pt_to_type(image, input_image_type):
if input_image_type == "pt":
input_image = image
......@@ -56,21 +76,32 @@ class PipelineLatentTesterMixin:
for image_param in self.image_params:
if image_param in inputs.keys():
inputs[image_param] = convert_pt_to_type(inputs[image_param], input_image_type)
inputs[image_param] = convert_pt_to_type(
convert_to_pt(inputs[image_param]).to(device), input_image_type
)
inputs["output_type"] = output_type
return inputs
def test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4):
self._test_pt_np_pil_outputs_equivalent(expected_max_diff=expected_max_diff)
def _test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4, input_image_type="pt"):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
output_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pt"))[0]
output_np = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="np"))[0]
output_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pil"))[0]
output_pt = pipe(
**self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pt")
)[0]
output_np = pipe(
**self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="np")
)[0]
output_pil = pipe(
**self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pil")
)[0]
max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
self.assertLess(
......@@ -98,6 +129,31 @@ class PipelineLatentTesterMixin:
max_diff = np.abs(out_input_pil - out_input_np).max()
self.assertLess(max_diff, 1e-2, "`input_type=='pt'` generate different result from `input_type=='np'`")
def test_latents_input(self):
if len(self.image_latents_params) == 0:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
vae = components["vae"]
inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt")
generator = inputs["generator"]
for image_param in self.image_latents_params:
if image_param in inputs.keys():
inputs[image_param] = (
vae.encode(inputs[image_param]).latent_dist.sample(generator) * vae.config.scaling_factor
)
out_latents_inputs = pipe(**inputs)[0]
max_diff = np.abs(out - out_latents_inputs).max()
self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
@require_torch
class PipelineTesterMixin:
......