Unverified Commit adbb0486 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] fix conversion utility so that lora dora loads correctly (#8688)

fix conversion utility so that lora dora loads correctly
parent effe4b97
...@@ -142,10 +142,10 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ ...@@ -142,10 +142,10 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
network_alphas = {} network_alphas = {}
# Check for DoRA-enabled LoRAs. # Check for DoRA-enabled LoRAs.
if any( dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
"dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k) dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
for k in state_dict dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
): if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
if is_peft_version("<", "0.9.0"): if is_peft_version("<", "0.9.0"):
raise ValueError( raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
...@@ -173,7 +173,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ ...@@ -173,7 +173,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# Store DoRA scale if present. # Store DoRA scale if present.
if "dora_scale" in state_dict: if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
unet_state_dict[ unet_state_dict[
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
...@@ -192,7 +192,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ ...@@ -192,7 +192,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# Store DoRA scale if present. # Store DoRA scale if present.
if "dora_scale" in state_dict: if dora_present_in_te or dora_present_in_te2:
dora_scale_key_to_replace_te = ( dora_scale_key_to_replace_te = (
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
) )
...@@ -214,7 +214,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ ...@@ -214,7 +214,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
if len(state_dict) > 0: if len(state_dict) > 0:
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}") raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
logger.info("Kohya-style checkpoint detected.") logger.info("Non-diffusers checkpoint detected.")
# Construct final state dict. # Construct final state dict.
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment