Unverified Commit e550163b authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Vae] Make sure all vae's work with latent diffusion models (#5880)

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more
parent 20f0cbc8
...@@ -17,7 +17,16 @@ from huggingface_hub import delete_repo ...@@ -17,7 +17,16 @@ from huggingface_hub import delete_repo
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import diffusers import diffusers
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderTiny,
ConsistencyDecoderVAE,
DDIMScheduler,
DiffusionPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging from diffusers.utils import logging
...@@ -28,6 +37,12 @@ from diffusers.utils.testing_utils import ( ...@@ -28,6 +37,12 @@ from diffusers.utils.testing_utils import (
torch_device, torch_device,
) )
from ..models.test_models_vae import (
get_asym_autoencoder_kl_config,
get_autoencoder_kl_config,
get_autoencoder_tiny_config,
get_consistency_vae_config,
)
from ..others.test_utils import TOKEN, USER, is_staging_test from ..others.test_utils import TOKEN, USER, is_staging_test
...@@ -171,6 +186,34 @@ class PipelineLatentTesterMixin: ...@@ -171,6 +186,34 @@ class PipelineLatentTesterMixin:
max_diff = np.abs(out - out_latents_inputs).max() max_diff = np.abs(out - out_latents_inputs).max()
self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image") self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
def test_multi_vae(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
block_out_channels = pipe.vae.config.block_out_channels
norm_num_groups = pipe.vae.config.norm_num_groups
vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
configs = [
get_autoencoder_kl_config(block_out_channels, norm_num_groups),
get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
get_consistency_vae_config(block_out_channels, norm_num_groups),
get_autoencoder_tiny_config(block_out_channels),
]
out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
for vae_cls, config in zip(vae_classes, configs):
vae = vae_cls(**config)
vae = vae.to(torch_device)
components["vae"] = vae
vae_pipe = self.pipeline_class(**components)
out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
assert out_vae_np.shape == out_np.shape
@require_torch @require_torch
class PipelineKarrasSchedulerTesterMixin: class PipelineKarrasSchedulerTesterMixin:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment