Unverified Commit a9c403c0 authored by Sayak Paul, committed by GitHub

[LoRA] refactor lora conversion utility. (#8295)

* refactor lora conversion utility.

* remove error raises.

* add onetrainer support too.
parent e7b9a076
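
As a quick illustration of the renamed entry point, here is a minimal sketch that feeds a toy Kohya-style LoRA state dict through the converter. The import path and function name come from this diff; the module key, tensor shapes, and alpha value are invented for illustration:

    import torch

    from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers

    # Hypothetical single-module Kohya-style LoRA (rank 4).
    base = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q"
    state_dict = {
        f"{base}.lora_down.weight": torch.randn(4, 320),  # rank-4 down projection
        f"{base}.lora_up.weight": torch.randn(320, 4),    # rank-4 up projection
        f"{base}.alpha": torch.tensor(4.0),               # per-module network alpha
    }

    converted, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
    print(list(converted))   # Diffusers-style keys, prefixed with "unet."
    print(network_alphas)    # alpha entries renamed via _get_alpha_name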
In the LoRA loader module (class LoraLoaderMixin):

@@ -43,7 +43,7 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
-from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
+from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers

 if is_transformers_available():
@@ -288,7 +288,7 @@ class LoraLoaderMixin:
             if unet_config is not None:
                 # use unet config to remap block numbers
                 state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
-            state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
+            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

         return state_dict, network_alphas
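
The loader change above only swaps the helper's name; the user-facing path is unchanged. A minimal sketch of loading a non-Diffusers LoRA through this code path (hypothetical weight file and directory), which routes through lora_state_dict and the conversion utility below:

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    # A Kohya/OneTrainer-style checkpoint triggers _convert_non_diffusers_lora_to_diffusers.
    pipe.load_lora_weights("path/to/lora_dir", weight_name="kohya_style_lora.safetensors")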
In lora_conversion_utils.py:

@@ -123,45 +123,127 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
     return new_state_dict


-def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+    """
+    Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
+
+    Args:
+        state_dict (`dict`): The state dict to convert.
+        unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
+        text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
+            "text_encoder".
+
+    Returns:
+        `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
+    """
     unet_state_dict = {}
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
-    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
-    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
-    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)

-    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+    # Check for DoRA-enabled LoRAs.
+    if any(
+        "dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
+        for k in state_dict
+    ):
         if is_peft_version("<", "0.9.0"):
             raise ValueError(
                 "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
             )

-    # every down weight has a corresponding up weight and potentially an alpha weight
-    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
-    for key in lora_keys:
+    # Iterate over all LoRA weights.
+    all_lora_keys = list(state_dict.keys())
+    for key in all_lora_keys:
+        if not key.endswith("lora_down.weight"):
+            continue
+
+        # Extract LoRA name.
         lora_name = key.split(".")[0]
+
+        # Find corresponding up weight and alpha.
         lora_name_up = lora_name + ".lora_up.weight"
         lora_name_alpha = lora_name + ".alpha"

+        # Handle U-Net LoRAs.
         if lora_name.startswith("lora_unet_"):
-            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
-            if "input.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
-            if "middle.block" in diffusers_name:
-                diffusers_name = diffusers_name.replace("middle.block", "mid_block")
-            else:
-                diffusers_name = diffusers_name.replace("mid.block", "mid_block")
-            if "output.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
-            diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
-            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+            diffusers_name = _convert_unet_lora_key(key)
+
+            # Store down and up weights.
+            unet_state_dict[diffusers_name] = state_dict.pop(key)
+            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            # Store DoRA scale if present.
+            if "dora_scale" in state_dict:
+                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
+                unet_state_dict[
+                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
+        # Handle text encoder LoRAs.
+        elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+            # Store down and up weights for te or te2.
+            if lora_name.startswith(("lora_te_", "lora_te1_")):
+                te_state_dict[diffusers_name] = state_dict.pop(key)
+                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            else:
+                te2_state_dict[diffusers_name] = state_dict.pop(key)
+                te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            # Store DoRA scale if present.
+            if "dora_scale" in state_dict:
+                dora_scale_key_to_replace_te = (
+                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
+                )
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                elif lora_name.startswith("lora_te2_"):
+                    te2_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
+        # Store alpha if present.
+        if lora_name_alpha in state_dict:
+            alpha = state_dict.pop(lora_name_alpha).item()
+            network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
+
+    # Check if any keys remain.
+    if len(state_dict) > 0:
+        raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
+
+    logger.info("Kohya-style checkpoint detected.")
+
+    # Construct final state dict.
+    unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
+    te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
+    te2_state_dict = (
+        {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
+        if len(te2_state_dict) > 0
+        else None
+    )
+    if te2_state_dict is not None:
+        te_state_dict.update(te2_state_dict)
+
+    new_state_dict = {**unet_state_dict, **te_state_dict}
+    return new_state_dict, network_alphas
+
+
+def _convert_unet_lora_key(key):
+    """
+    Converts a U-Net LoRA key to a Diffusers compatible key.
+    """
+    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+    # Replace common U-Net naming patterns.
+    diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+    diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
@@ -171,7 +253,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     diffusers_name = diffusers_name.replace("proj.out", "proj_out")
     diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")

-    # SDXL specificity.
+    # SDXL specific conversions.
     if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
         pattern = r"\.\d+(?=\D*$)"
         diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
@@ -184,36 +266,31 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     if "skip" in diffusers_name:
         diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")

-    # LyCORIS specificity.
+    # LyCORIS specific conversions.
     if "time.emb.proj" in diffusers_name:
         diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
     if "conv.shortcut" in diffusers_name:
         diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")

-    # General coverage.
+    # General conversions.
     if "transformer_blocks" in diffusers_name:
         if "attn1" in diffusers_name or "attn2" in diffusers_name:
             diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
             diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
-            unet_state_dict[diffusers_name] = state_dict.pop(key)
-            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
         elif "ff" in diffusers_name:
-            unet_state_dict[diffusers_name] = state_dict.pop(key)
-            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            pass
         elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-            unet_state_dict[diffusers_name] = state_dict.pop(key)
-            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            pass
         else:
-            unet_state_dict[diffusers_name] = state_dict.pop(key)
-            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
-    if is_unet_dora_lora:
-        dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
-        unet_state_dict[
-            diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
-        ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
-
-    elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            pass
+
+    return diffusers_name
+
+
+def _convert_text_encoder_lora_key(key, lora_name):
+    """
+    Converts a text encoder LoRA key to a Diffusers compatible key.
+    """
     if lora_name.startswith(("lora_te_", "lora_te1_")):
         key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
     else:
@@ -228,44 +305,19 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
     diffusers_name = diffusers_name.replace("text.projection", "text_projection")

-    if "self_attn" in diffusers_name:
-        if lora_name.startswith(("lora_te_", "lora_te1_")):
-            te_state_dict[diffusers_name] = state_dict.pop(key)
-            te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-        else:
-            te2_state_dict[diffusers_name] = state_dict.pop(key)
-            te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+    if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
+        pass
     elif "mlp" in diffusers_name:
         # Be aware that this is the new diffusers convention and the rest of the code might
         # not utilize it yet.
         diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-        if lora_name.startswith(("lora_te_", "lora_te1_")):
-            te_state_dict[diffusers_name] = state_dict.pop(key)
-            te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-        else:
-            te2_state_dict[diffusers_name] = state_dict.pop(key)
-            te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-    # OneTrainer specificity
-    elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
-        te2_state_dict[diffusers_name] = state_dict.pop(key)
-        te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
-    if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
-        dora_scale_key_to_replace_te = (
-            "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
-        )
-        if lora_name.startswith(("lora_te_", "lora_te1_")):
-            te_state_dict[
-                diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
-            ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
-        elif lora_name.startswith("lora_te2_"):
-            te2_state_dict[
-                diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
-            ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
-
-    # Rename the alphas so that they can be mapped appropriately.
-    if lora_name_alpha in state_dict:
-        alpha = state_dict.pop(lora_name_alpha).item()
+
+    return diffusers_name
+
+
+def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
+    """
+    Gets the correct alpha name for the Diffusers model.
+    """
     if lora_name_alpha.startswith("lora_unet_"):
         prefix = "unet."
     elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
@@ -273,21 +325,4 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     else:
         prefix = "text_encoder_2."
     new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
-    network_alphas.update({new_name: alpha})
-
-    if len(state_dict) > 0:
-        raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
-
-    logger.info("Kohya-style checkpoint detected.")
-
-    unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
-    te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
-    te2_state_dict = (
-        {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
-        if len(te2_state_dict) > 0
-        else None
-    )
-    if te2_state_dict is not None:
-        te_state_dict.update(te2_state_dict)
-
-    new_state_dict = {**unet_state_dict, **te_state_dict}
-    return new_state_dict, network_alphas
+    return {new_name: alpha}
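
To make the split into helpers concrete, here is a hand-traced example of the U-Net key mapping performed by _convert_unet_lora_key (hypothetical key; the steps follow the replacements shown in the hunks above):

    key = "lora_unet_mid_block_attentions_0_transformer_blocks_0_attn2_to_k.lora_down.weight"
    # 1. Strip "lora_unet_" and turn "_" into ".":
    #    "mid.block.attentions.0.transformer.blocks.0.attn2.to.k.lora.down.weight"
    # 2. "mid.block" -> "mid_block", "transformer.blocks" -> "transformer_blocks"
    # 3. "to.k.lora" -> "to_k_lora"
    # 4. General transformer_blocks coverage: "attn2" -> "attn2.processor"
    # Result, stored by the caller under the "unet." prefix:
    #    "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_lora.down.weight"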