Unverified Commit c14057c8 authored by Sayak Paul, committed by GitHub

[LoRA] improve lora support for flux. (#10810)

Extend the Kohya-style Flux LoRA conversion so that CLIP text encoder ("lora_te1_") keys are converted instead of tripping the transformer-only assertion, stray "lora_unet_" keys are discarded rather than raising, and the converted state dict is returned with "transformer." and "text_encoder." prefixes so each part loads into the right pipeline component.
parent 3579cd2b
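
In practice this lets a Kohya-trained Flux LoRA that also carries CLIP text encoder modules load end to end. A minimal usage sketch; the LoRA repository id and weight filename below are placeholders, not from this commit:

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    # Before this change, checkpoints containing "lora_te1_" keys failed the
    # converter's all-"lora_transformer_" assertion; now the text encoder part
    # is converted and applied alongside the transformer part.
    pipe.load_lora_weights("some-user/some-kohya-flux-lora", weight_name="lora.safetensors")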
@@ -588,11 +588,13 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
             new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
 
         all_unique_keys = {
-            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
+            for k in state_dict
+            if not k.startswith(("lora_unet_"))
         }
         all_unique_keys = sorted(all_unique_keys)
-        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+        assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"
 
+        has_te_keys = False
         for k in all_unique_keys:
             if k.startswith("lora_transformer_single_transformer_blocks_"):
                 i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
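
The comprehension above collapses the three entries each LoRA module ships (".lora_down.weight", ".lora_up.weight", ".alpha") into a single base key, and the new filter drops UNet-style keys up front. A standalone sketch with invented keys:

    sample_keys = [
        "lora_transformer_transformer_blocks_0_attn_to_q.lora_down.weight",
        "lora_transformer_transformer_blocks_0_attn_to_q.lora_up.weight",
        "lora_transformer_transformer_blocks_0_attn_to_q.alpha",
        "lora_unet_some_module.lora_down.weight",  # now filtered out instead of failing the assert
    ]
    all_unique_keys = {
        k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
        for k in sample_keys
        if not k.startswith("lora_unet_")
    }
    print(sorted(all_unique_keys))
    # ['lora_transformer_transformer_blocks_0_attn_to_q']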
@@ -600,6 +602,9 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
             elif k.startswith("lora_transformer_transformer_blocks_"):
                 i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
                 diffusers_key = f"transformer_blocks.{i}"
+            elif k.startswith("lora_te1_"):
+                has_te_keys = True
+                continue
             else:
                 raise NotImplementedError
 
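
The new branch only records that text encoder keys exist and defers them; the first pass keeps resolving transformer blocks by index. A small sketch of that dispatch, with invented keys:

    has_te_keys = False
    for k in [
        "lora_transformer_transformer_blocks_3_attn_to_k",
        "lora_te1_text_model_encoder_layers_0_self_attn_q_proj",
    ]:
        if k.startswith("lora_transformer_transformer_blocks_"):
            i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
            print(f"transformer_blocks.{i}")  # -> transformer_blocks.3
        elif k.startswith("lora_te1_"):
            has_te_keys = True  # handled by the dedicated text encoder pass below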
@@ -615,17 +620,57 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
                 remaining = k.split("attn_")[-1]
                 diffusers_key += f".attn.{remaining}"
 
             if diffusers_key == f"transformer_blocks.{i}":
                 print(k, diffusers_key)
             _convert(k, diffusers_key, state_dict, new_state_dict)
 
+        if has_te_keys:
+            layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
+            attn_mapping = {
+                "q_proj": ".self_attn.q_proj",
+                "k_proj": ".self_attn.k_proj",
+                "v_proj": ".self_attn.v_proj",
+                "out_proj": ".self_attn.out_proj",
+            }
+            mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
+            for k in all_unique_keys:
+                if not k.startswith("lora_te1_"):
+                    continue
+
+                match = layer_pattern.search(k)
+                if not match:
+                    continue
+
+                i = int(match.group(1))
+                diffusers_key = f"text_model.encoder.layers.{i}"
+
+                if "attn" in k:
+                    for key_fragment, suffix in attn_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+                elif "mlp" in k:
+                    for key_fragment, suffix in mlp_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+
+                _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        if state_dict:
+            remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
+            if remaining_all_unet:
+                keys = list(state_dict.keys())
+                for k in keys:
+                    state_dict.pop(k)
+
         if len(state_dict) > 0:
             raise ValueError(
                 f"Expected an empty state dict at this point but it has these keys which couldn't be parsed: {list(state_dict.keys())}."
             )
 
-        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
-        return new_state_dict
+        transformer_state_dict = {
+            f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
+        }
+        te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
+        return {**transformer_state_dict, **te_state_dict}
 
     # This is weird.
     # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
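
To make the text encoder pass concrete, this is what the layer regex plus suffix mapping do to one key, followed by the final prefixing step (the example key is invented):

    import re

    layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
    k = "lora_te1_text_model_encoder_layers_7_self_attn_q_proj"
    i = int(layer_pattern.search(k).group(1))
    diffusers_key = f"text_model.encoder.layers.{i}" + ".self_attn.q_proj"
    # The return statement then routes it to the CLIP text encoder:
    print(f"text_encoder.{diffusers_key}")
    # text_encoder.text_model.encoder.layers.7.self_attn.q_proj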
@@ -640,6 +685,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
     )
     if has_mixture:
         return _convert_mixture_state_dict_to_diffusers(state_dict)
 
+    return _convert_sd_scripts_to_ai_toolkit(state_dict)
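
Taken together, the converter now emits a single dict whose key prefixes tell the loader which pipeline component each weight belongs to. A sketch of the final split, with stub values standing in for tensors:

    new_state_dict = {
        "transformer_blocks.0.attn.to_q.lora_A.weight": "stub",
        "text_model.encoder.layers.7.self_attn.q_proj.lora_A.weight": "stub",
    }
    transformer_state_dict = {
        f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
    }
    te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
    print(list({**transformer_state_dict, **te_state_dict}))
    # ['transformer.transformer_blocks.0.attn.to_q.lora_A.weight',
    #  'text_encoder.text_model.encoder.layers.7.self_attn.q_proj.lora_A.weight']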