Unverified Commit 963ffca4 authored by Emmanuel Benazera's avatar Emmanuel Benazera Committed by GitHub
Browse files

fix: missing AutoencoderKL lora adapter (#9807)



* fix: missing AutoencoderKL lora adapter

* fix

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 30f2e9bd
@@ -17,6 +17,7 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
@@ -34,7 +35,7 @@ from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
    ...
@@ -36,7 +36,9 @@ from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    is_peft_available,
    load_hf_numpy,
    require_peft_backend,
    require_torch_accelerator,
    require_torch_accelerator_with_fp16,
    require_torch_gpu,
@@ -50,6 +52,10 @@ from diffusers.utils.torch_utils import randn_tensor
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin

if is_peft_available():
    from peft import LoraConfig

enable_full_determinism()
@@ -263,6 +269,38 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@require_peft_backend
def test_lora_adapter(self):
    """Verify a PEFT LoRA adapter can be attached to the VAE and is reported as the sole active adapter."""
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict)

    # Convolution and attention-projection layer names targeted for LoRA injection.
    lora_targets = [
        "conv1",
        "conv2",
        "conv_in",
        "conv_shortcut",
        "conv",
        "conv_out",
        "skip_conv_1",
        "skip_conv_2",
        "skip_conv_3",
        "skip_conv_4",
        "to_k",
        "to_q",
        "to_v",
        "to_out.0",
    ]
    lora_config = LoraConfig(
        r=16,
        init_lora_weights="gaussian",
        target_modules=lora_targets,
    )
    model.add_adapter(lora_config, adapter_name="vae_lora")

    adapters = model.active_adapters()
    self.assertTrue(len(adapters) == 1)
    self.assertTrue(adapters[0] == "vae_lora")
class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AsymmetricAutoencoderKL
    ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment