Unverified Commit 8a692739, authored by Younes Belkada and committed by GitHub
Browse files

FIX [`PEFT` / `Core`] Copy the state dict when passing it to `load_lora_weights` (#7058)

* copy the state dict in load lora weights

* fixup
parent 5aa31bd6
...@@ -106,6 +106,10 @@ class LoraLoaderMixin: ...@@ -106,6 +106,10 @@ class LoraLoaderMixin:
if not USE_PEFT_BACKEND: if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.") raise ValueError("PEFT backend is required for this method.")
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
...@@ -1229,6 +1233,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): ...@@ -1229,6 +1233,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
# it here explicitly to be able to tell that it's coming from an SDXL # it here explicitly to be able to tell that it's coming from an SDXL
# pipeline. # pipeline.
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict( state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
......
...@@ -26,11 +26,13 @@ import torch.nn as nn ...@@ -26,11 +26,13 @@ import torch.nn as nn
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard from huggingface_hub.repocard import RepoCard
from packaging import version from packaging import version
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
AutoPipelineForImage2Image, AutoPipelineForImage2Image,
AutoPipelineForText2Image,
ControlNetModel, ControlNetModel,
DDIMScheduler, DDIMScheduler,
DiffusionPipeline, DiffusionPipeline,
...@@ -1745,6 +1747,40 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -1745,6 +1747,40 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3)) self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
release_memory(pipe) release_memory(pipe)
def test_not_empty_state_dict(self):
    """Regression test for https://github.com/huggingface/diffusers/issues/7054.

    ``load_lora_weights`` used to consume a user-supplied state dict in
    place, leaving the caller holding an empty dict. After loading, the
    dict passed in must still contain its entries.
    """
    pipe = AutoPipelineForText2Image.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

    cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
    lcm_lora = load_file(cached_file)

    pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
    # assertNotEqual produces an informative diff on failure, unlike
    # the original assertTrue(lcm_lora != {}).
    self.assertNotEqual(lcm_lora, {})

    release_memory(pipe)
def test_load_unload_load_state_dict(self):
    """Regression test for https://github.com/huggingface/diffusers/issues/7054.

    A user-supplied LoRA state dict must survive a load / unload / load
    cycle unchanged — the pipeline may not mutate it in place.
    """
    pipe = AutoPipelineForText2Image.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

    weights_path = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
    state_dict = load_file(weights_path)
    # Snapshot taken before any pipeline call; every comparison below is
    # against this pristine copy.
    snapshot = state_dict.copy()

    pipe.load_lora_weights(state_dict, adapter_name="lcm")
    self.assertDictEqual(state_dict, snapshot)

    pipe.unload_lora_weights()

    # Loading the same dict a second time must also leave it untouched.
    pipe.load_lora_weights(state_dict, adapter_name="lcm")
    self.assertDictEqual(state_dict, snapshot)

    release_memory(pipe)
@slow @slow
@require_torch_gpu @require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment