"graphbolt/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b03d70d39a42a0587b3f53400cffb129dadf4a62"
Unverified Commit d486f0e8 authored by Sayak Paul, committed by GitHub

[LoRA serialization] fix: duplicate unet prefix problem. (#5991)



* fix: duplicate unet prefix problem.

* Update src/diffusers/loaders/lora.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 33512706
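For context, the duplicated prefix arose because the training-side `unet_lora_state_dict` helper already prefixed every key with `unet.`, and `save_lora_weights` then prefixed the packed dict with `unet.` a second time. A toy illustration with plain dicts (not the real diffusers objects):

# Pre-fix behaviour, sketched with plain dicts.
# The training helper already emitted "unet."-prefixed keys ...
training_side = {"unet.mid_block.attn.to_q.lora.down.weight": "tensor"}

# ... and the serialization path prefixed them with "unet." again,
# producing the "unet.unet." keys that this commit now deprecates.
serialized = {f"unet.{key}": value for key, value in training_side.items()}
print(list(serialized))  # ['unet.unet.mid_block.attn.to_q.lora.down.weight']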
@@ -391,6 +391,10 @@ class LoraLoaderMixin:
         # their prefixes.
         keys = list(state_dict.keys())

+        if all(key.startswith("unet.unet") for key in keys):
+            deprecation_message = "Keys starting with 'unet.unet' are deprecated."
+            deprecate("unet.unet keys", "0.27", deprecation_message)
+
         if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
@@ -407,8 +411,9 @@ class LoraLoaderMixin:
         else:
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            logger.warn(warn_message)
+            if not USE_PEFT_BACKEND:
+                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+                logger.warn(warn_message)

         if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
             from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
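As the warning above spells out, old-format (prefix-less) UNet LoRA weights can be converted by hand before loading; a minimal sketch of that conversion:

# Minimal sketch of the conversion suggested by the warning message:
# old-format keys carry no prefix, the new format expects a leading "unet.".
old_state_dict = {"mid_block.attn.to_q.lora.down.weight": "tensor"}
new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}
print(list(new_state_dict))  # ['unet.mid_block.attn.to_q.lora.down.weight']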
@@ -800,29 +805,21 @@ class LoraLoaderMixin:
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
-        # Create a flat dictionary.
         state_dict = {}

-        # Populate the dictionary.
-        if unet_lora_layers is not None:
-            weights = (
-                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
-            )
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict

-            unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
-            state_dict.update(unet_lora_state_dict)
+        if not (unet_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")

-        if text_encoder_lora_layers is not None:
-            weights = (
-                text_encoder_lora_layers.state_dict()
-                if isinstance(text_encoder_lora_layers, torch.nn.Module)
-                else text_encoder_lora_layers
-            )
+        if unet_lora_layers:
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))

-            text_encoder_lora_state_dict = {
-                f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
-            }
-            state_dict.update(text_encoder_lora_state_dict)
+        if text_encoder_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))

         # Save the model
         cls.write_lora_layers(
...
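The new `pack_weights` helper accepts either a module or an already-flat dict and applies the prefix in one place; a condensed, standalone copy of it exercised on toy inputs:

import torch

# Condensed copy of the pack_weights helper introduced above, run on toy inputs.
def pack_weights(layers, prefix):
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}

state_dict = {}
state_dict.update(pack_weights({"to_q.lora.down.weight": torch.zeros(4, 4)}, "unet"))
state_dict.update(pack_weights(torch.nn.Linear(2, 2), "text_encoder"))
print(sorted(state_dict))
# ['text_encoder.bias', 'text_encoder.weight', 'unet.to_q.lora.down.weight']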
@@ -67,7 +67,7 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
            current_lora_layer_sd = lora_layer.state_dict()
            for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                # The matrix name can either be "down" or "up".
-               lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
+               lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param

    return lora_state_dict
...
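With the `unet.` prefix dropped here, `unet_lora_state_dict` emits bare module keys and the prefix is added exactly once at save time via `pack_weights`, so serialized keys come out as `unet.<module>.lora.<down|up>...` rather than `unet.unet...`. A toy mimic of the loop above (plain dicts and a hypothetical module name, not real LoRA layers):

# Toy mimic of the post-fix key format produced by the loop above.
lora_state_dict = {}
name = "mid_block.attentions.0.proj_in"
current_lora_layer_sd = {"down.weight": "tensor_down", "up.weight": "tensor_up"}
for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
    # The matrix name can either be "down" or "up".
    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param
print(list(lora_state_dict))
# ['mid_block.attentions.0.proj_in.lora.down.weight', 'mid_block.attentions.0.proj_in.lora.up.weight']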