"docs/source/index.mdx" did not exist on "eb1c331c843cd16ad3c5444fcb0a0ddafc87febe"
Unverified commit e6df8eda, authored by Sayak Paul, committed by GitHub
Browse files

[LoRA] attempt at fixing onetrainer lora. (#8242)

* attempt at fixing onetrainer lora.

* fix
parent 80cfaeba
...@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_ ...@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
if "self_attn" in diffusers_name: if "self_attn" in diffusers_name:
if lora_name.startswith(("lora_te_", "lora_te1_")): if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[diffusers_name] = state_dict.pop(key) te_state_dict[diffusers_name] = state_dict.pop(key)
...@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_ ...@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
else: else:
te2_state_dict[diffusers_name] = state_dict.pop(key) te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# OneTrainer specificity
elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
dora_scale_key_to_replace_te = ( dora_scale_key_to_replace_te = (
...@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_ ...@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
network_alphas.update({new_name: alpha}) network_alphas.update({new_name: alpha})
if len(state_dict) > 0: if len(state_dict) > 0:
raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}") raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
logger.info("Kohya-style checkpoint detected.") logger.info("Kohya-style checkpoint detected.")
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
......
...@@ -62,6 +62,8 @@ DIFFUSERS_TO_PEFT = { ...@@ -62,6 +62,8 @@ DIFFUSERS_TO_PEFT = {
".out_proj.lora_linear_layer.down": ".out_proj.lora_A", ".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
".lora_linear_layer.up": ".lora_B", ".lora_linear_layer.up": ".lora_B",
".lora_linear_layer.down": ".lora_A", ".lora_linear_layer.down": ".lora_A",
"text_projection.lora.down.weight": "text_projection.lora_A.weight",
"text_projection.lora.up.weight": "text_projection.lora_B.weight",
} }
DIFFUSERS_OLD_TO_PEFT = { DIFFUSERS_OLD_TO_PEFT = {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment