Unverified Commit 43979c28 authored by Benjamin Bossan's avatar Benjamin Bossan Committed by GitHub
Browse files

TST Fix LoRA test that fails with PEFT >= 0.7.0 (#6216)



See #6185 for context.
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 9ea6ac1b
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import importlib
import os
import tempfile
import time
......@@ -24,6 +25,7 @@ import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard
from packaging import version
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import (
......@@ -1983,10 +1985,26 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
fused_te_2_state_dict = pipe.text_encoder_2.state_dict()
unet_state_dict = pipe.unet.state_dict()
peft_ge_070 = version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0")
def remap_key(key, sd):
# some keys have moved around for PEFT >= 0.7.0, but they should still be loaded correctly
if (key in sd) or (not peft_ge_070):
return key
# instead of linear.weight, we now have linear.base_layer.weight, etc.
if key.endswith(".weight"):
key = key[:-7] + ".base_layer.weight"
elif key.endswith(".bias"):
key = key[:-5] + ".base_layer.bias"
return key
for key, value in text_encoder_1_sd.items():
key = remap_key(key, fused_te_state_dict)
self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
for key, value in text_encoder_2_sd.items():
key = remap_key(key, fused_te_2_state_dict)
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
for key, value in unet_state_dict.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment