Unverified Commit 01ac37b3 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] Clean Kohya conversion utils (#7374)

* clean up the kohya_conversion utility

* state dict assignment
parent 6a05b274
...@@ -198,27 +198,13 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_ ...@@ -198,27 +198,13 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif lora_name.startswith("lora_te_"): elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
diffusers_name = key.replace("lora_te_", "").replace("_", ".") if lora_name.startswith(("lora_te_", "lora_te1_")):
diffusers_name = diffusers_name.replace("text.model", "text_model") key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
diffusers_name = diffusers_name.replace("self.attn", "self_attn") else:
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") key_to_replace = "lora_te2_"
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned. diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
elif lora_name.startswith("lora_te1_"):
diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model") diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn") diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
...@@ -226,31 +212,20 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_ ...@@ -226,31 +212,20 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name: if "self_attn" in diffusers_name:
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[diffusers_name] = state_dict.pop(key) te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name: else:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif lora_name.startswith("lora_te2_"):
diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te2_state_dict[diffusers_name] = state_dict.pop(key) te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name: elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might # Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet. # not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
te2_state_dict[diffusers_name] = state_dict.pop(key) te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment