"doc/git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "98dbdcd66c60116b0fbbf201692c860014521d1e"
Unverified commit dd9a5caf authored by Sayak Paul, committed by GitHub

[Core] support for tiny autoencoder in img2img (#5636)



* support for tiny autoencoder in img2img
Co-authored-by: slep0v <37597789+slep0v@users.noreply.github.com>

* copy fix

* line space

* line space

* clean up

* spit out expected value

* spit out expected value

* assertion values.

* assertion values.

---------
Co-authored-by: slep0v <37597789+slep0v@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent a35e72b0
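
What the change enables, in practice: every VAE encode call in the img2img and inpaint pipelines now goes through a small retrieve_latents helper, so a distilled AutoencoderTiny (TAESD) can be swapped in as the pipeline's vae. Below is a minimal usage sketch, not part of this commit; the Hub checkpoint IDs and the local input file name are illustrative assumptions.

# Sketch: swapping a tiny autoencoder into an img2img pipeline.
# "runwayml/stable-diffusion-v1-5" and "madebyollin/taesd" are assumed
# checkpoint IDs for illustration; "input.png" is any local RGB image.
import torch
from PIL import Image
from diffusers import AutoencoderTiny, StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Replace the default AutoencoderKL with the distilled tiny autoencoder.
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

init_image = Image.open("input.png").convert("RGB").resize((512, 512))
result = pipe(
    prompt="a fantasy landscape, highly detailed",
    image=init_image,
    strength=0.75,
    num_inference_steps=30,
).images[0]
result.save("img2img_taesd.png")
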
@@ -75,6 +75,16 @@ EXAMPLE_DOC_STRING = """
 """

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
     deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"

@@ -561,11 +571,12 @@ class AltDiffusionImg2ImgPipeline(
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         init_latents = self.vae.config.scaling_factor * init_latents
...
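
For context (not part of the diff): the helper is needed because the two autoencoders expose their encodings differently. AutoencoderKL.encode(...) returns an output whose latent_dist must be sampled, while AutoencoderTiny.encode(...) returns deterministic latents directly. A minimal sketch of that difference, using toy configurations comparable to the dummy models in the tests further down:

# Sketch of the two encoder outputs retrieve_latents has to handle.
# The tiny configs below are illustrative assumptions mirroring the test dummies.
import torch
from diffusers import AutoencoderKL, AutoencoderTiny

image = torch.randn(1, 3, 64, 64)

kl_vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    latent_channels=4,
    block_out_channels=[32, 64],
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
    up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
)
tiny_vae = AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)

# AutoencoderKL: the output carries a DiagonalGaussianDistribution that is sampled.
kl_latents = kl_vae.encode(image).latent_dist.sample()

# AutoencoderTiny: the output carries the latents directly, nothing to sample.
tiny_latents = tiny_vae.encode(image).latents

print(kl_latents.shape, tiny_latents.shape)
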
@@ -91,6 +91,16 @@ EXAMPLE_DOC_STRING = """
 """

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 def prepare_image(image):
     if isinstance(image, torch.Tensor):
         # Batch single image

@@ -733,11 +743,12 @@ class StableDiffusionControlNetImg2ImgPipeline(
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         init_latents = self.vae.config.scaling_factor * init_latents
...
@@ -103,6 +103,16 @@ EXAMPLE_DOC_STRING = """
 """

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
 def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
     """

@@ -949,12 +959,12 @@ class StableDiffusionControlNetInpaintPipeline(
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         image_latents = self.vae.config.scaling_factor * image_latents
...
@@ -131,6 +131,16 @@ EXAMPLE_DOC_STRING = """
 """

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 class StableDiffusionXLControlNetImg2ImgPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
 ):

@@ -806,11 +816,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         if self.vae.config.force_upcast:
             self.vae.to(dtype)
...
@@ -43,6 +43,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSaf
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py

@@ -426,11 +436,12 @@ class LatentConsistencyModelImg2ImgPipeline(
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         init_latents = self.vae.config.scaling_factor * init_latents
...
@@ -34,6 +34,16 @@ from .image_encoder import PaintByExampleImageEncoder
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 def prepare_mask_and_masked_image(image, mask):
     """
     Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be

@@ -334,12 +344,12 @@ class PaintByExamplePipeline(DiffusionPipeline):
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         image_latents = self.vae.config.scaling_factor * image_latents
...
@@ -36,6 +36,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
     deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"

@@ -466,11 +476,12 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         init_latents = self.vae.config.scaling_factor * init_latents
...
@@ -73,6 +73,15 @@ EXAMPLE_DOC_STRING = """
 """

+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 def preprocess(image):
     deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
     deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)

@@ -555,11 +564,12 @@ class StableDiffusionImg2ImgPipeline(
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         init_latents = self.vae.config.scaling_factor * init_latents
...
@@ -159,6 +159,16 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 class StableDiffusionInpaintPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
 ):

@@ -654,12 +664,12 @@ class StableDiffusionInpaintPipeline(
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         image_latents = self.vae.config.scaling_factor * image_latents
...
@@ -92,6 +92,16 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 class StableDiffusionXLImg2ImgPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
 ):

@@ -604,11 +614,12 @@ class StableDiffusionXLImg2ImgPipeline(
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         if self.vae.config.force_upcast:
             self.vae.to(dtype)
...
@@ -238,6 +238,16 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
 class StableDiffusionXLInpaintPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
 ):

@@ -750,12 +760,12 @@ class StableDiffusionXLInpaintPipeline(
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
            ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

         if self.vae.config.force_upcast:
             self.vae.to(dtype)
...
@@ -24,6 +24,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 from diffusers import (
     AutoencoderKL,
+    AutoencoderTiny,
     DDIMScheduler,
     DPMSolverMultistepScheduler,
     HeunDiscreteScheduler,

@@ -148,6 +149,9 @@ class StableDiffusionImg2ImgPipelineFastTests(
         }
         return components

+    def get_dummy_tiny_autoencoder(self):
+        return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)
+
     def get_dummy_inputs(self, device, seed=0):
         image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
         image = image / 2 + 0.5

@@ -236,6 +240,23 @@ class StableDiffusionImg2ImgPipelineFastTests(
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

+    def test_stable_diffusion_img2img_tiny_autoencoder(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.vae = self.get_dummy_tiny_autoencoder()
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.00669, 0.00669, 0.0, 0.00693, 0.00858, 0.0, 0.00567, 0.00515, 0.00125])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
     @skip_mps
     def test_save_load_local(self):
         return super().test_save_load_local()
...
@@ -22,6 +22,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProject
 from diffusers import (
     AutoencoderKL,
+    AutoencoderTiny,
     EulerDiscreteScheduler,
     StableDiffusionXLImg2ImgPipeline,
     UNet2DConditionModel,

@@ -121,6 +122,9 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         }
         return components

+    def get_dummy_tiny_autoencoder(self):
+        return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)
+
     def test_components_function(self):
         init_components = self.get_dummy_components()
         init_components.pop("requires_aesthetics_score")

@@ -216,6 +220,23 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4

+    def test_stable_diffusion_xl_img2img_tiny_autoencoder(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe.vae = self.get_dummy_tiny_autoencoder()
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.0, 0.0, 0.0106, 0.0, 0.0, 0.0087, 0.0052, 0.0062, 0.0177])
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
     @require_torch_gpu
     def test_stable_diffusion_xl_offloads(self):
         pipes = []
...