Unverified Commit 5712c3d2 authored by ilisparrow, committed by GitHub

[Core] enable lora for sdxl adapters too and add slow tests. (#5555)



* Enable lora for sdxl adapters too.

Issue #5516

* fix: assertion values.

* Use numpy_cosine_similarity_distance on the arrays
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Use numpy_cosine_similarity_distance on the arrays
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Changed imports orders to pass tests
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

---------
Co-authored-by: Ilias A <iliasamri00@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 151998e1
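
For context, this change lets the SDXL adapter pipeline load LoRA weights the same way the base SDXL pipeline does. Below is a minimal usage sketch of the workflow the new slow test exercises; the checkpoints, image URL, and the 3-step setting are taken directly from the test added in this commit, not a recommended production configuration:

```python
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

# Checkpoints referenced by the slow test in this commit.
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16, variant="fp16"
)

# This call is what the commit enables for the adapter pipeline.
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
pipe.enable_sequential_cpu_offload()

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
)
# Three steps keeps the test fast; increase for real generations.
result = pipe("toy", image=image, num_inference_steps=3, output_type="np").images[0]
```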
@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -1066,3 +1067,77 @@ class StableDiffusionXLAdapterPipeline(
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
    # Override to properly handle the loading and unloading of the additional text encoder.
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.
        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
            )

    @classmethod
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        state_dict = {}

        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
\ No newline at end of file
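
The save path above relies on `pack_weights` to namespace each sub-model's LoRA tensors under a common prefix before everything is written to a single file. A minimal standalone sketch of that key-prefixing behaviour follows; the `torch.nn.Linear` module is a hypothetical stand-in for a real collection of LoRA layers, used purely for illustration:

```python
import torch


def pack_weights(layers, prefix):
    # Same logic as the helper above: accept either an nn.Module or a plain
    # state dict, and prefix every key with the sub-model name.
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{name}": param for name, param in layers_weights.items()}


# Hypothetical stand-in for the unet's LoRA layers.
dummy = torch.nn.Linear(4, 4)
packed = pack_weights(dummy, "unet")
print(sorted(packed))  # ['unet.bias', 'unet.weight']
```

This prefixing is also what `load_lora_weights` undoes on the way back in: keys containing `text_encoder.` and `text_encoder_2.` are routed to the respective text encoders, and the rest goes to the unet.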
@@ -14,6 +14,7 @@
# limitations under the License.
import gc
import random
import unittest

import numpy as np
@@ -29,10 +30,14 @@ from diffusers import (
    EulerAncestralDiscreteScheduler,
    StableDiffusionXLAdapterPipeline,
    T2IAdapter,
    UNet2DConditionModel,
)
from diffusers.utils import load_image, logging
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import (
    PipelineTesterMixin,
@@ -560,3 +565,64 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
        if test_mean_pixel_difference:
            assert_mean_pixel_difference(output_batch[0][0], output[0][0])
@slow
@require_torch_gpu
class AdapterSDXLPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
    def test_canny_lora(self):
        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16).to(
            "cpu"
        )
        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            adapter=adapter,
            torch_dtype=torch.float16,
            variant="fp16",
        )
        pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
        pipe.enable_sequential_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "toy"

        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
        )

        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

        assert images[0].shape == (768, 512, 3)

        original_image = images[0, -3:, -3:, -1].flatten()
        expected_image = np.array(
            [0.50346327, 0.50708383, 0.50719553, 0.5135172, 0.5155377, 0.5066059, 0.49680984, 0.5005894, 0.48509413]
        )
        assert numpy_cosine_similarity_distance(original_image, expected_image) < 1e-4
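
The final assertion compares the generated corner slice against the reference values with `numpy_cosine_similarity_distance` from `diffusers.utils.testing_utils`. As a rough sketch of what such a check amounts to (not necessarily the library's exact implementation), the distance is one minus the cosine similarity of the two flattened arrays:

```python
import numpy as np


def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cos(theta) between the flattened arrays; 0.0 means identical direction.
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)


# The slow test requires this distance to stay below 1e-4 for the 3x3 corner slice.
a = np.linspace(0.48, 0.52, 9)           # stand-in for the generated slice
b = a + 1e-5                             # tiny perturbation
print(cosine_similarity_distance(a, b))  # well below the 1e-4 threshold
```

Compared with an element-wise `np.allclose`, this direction-based check is less sensitive to small uniform brightness shifts across GPUs, which is why the PR switched to it.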