Unverified Commit 0b63ad5a authored by apolinário's avatar apolinário Committed by GitHub
Browse files

Create convert_diffusers_sdxl_lora_to_webui.py (#6395)



* Create convert_diffusers_sdxl_lora_to_webui.py

* Move some conversion logic to utils

* fix logging import

* Add usage example

---------
Co-authored-by: default avatarmultimodalart <joaopaulo.passos+multimodal@gmail.com>
parent 6a376cee
# Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
# This means that you can input your diffusers-trained LoRAs and
# Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
# To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
# and run the script:
# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
# now you can use corgy.safetensors in your WebUI of choice!
# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
# Canonical diffusers training scripts:
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
import argparse
import os
from safetensors.torch import load_file, save_file
from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
def convert_and_save(input_lora, output_lora=None):
if output_lora is None:
base_name = os.path.splitext(input_lora)[0]
output_lora = f"{base_name}_webui.safetensors"
diffusers_state_dict = load_file(input_lora)
peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, output_lora)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
parser.add_argument(
"input_lora",
type=str,
help="Path to the input LoRA model file in the diffusers format.",
)
parser.add_argument(
"output_lora",
type=str,
nargs="?",
help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
)
args = parser.parse_args()
convert_and_save(args.input_lora, args.output_lora)
......@@ -16,6 +16,11 @@ State dict utilities: utility methods for converting state dicts easily
"""
import enum
from .logging import get_logger
logger = get_logger(__name__)
class StateDictType(enum.Enum):
"""
......@@ -23,7 +28,7 @@ class StateDictType(enum.Enum):
"""
DIFFUSERS_OLD = "diffusers_old"
# KOHYA_SS = "kohya_ss" # TODO: implement this
KOHYA_SS = "kohya_ss"
PEFT = "peft"
DIFFUSERS = "diffusers"
......@@ -100,6 +105,14 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
}
PEFT_TO_KOHYA_SS = {
"lora_A": "lora_down",
"lora_B": "lora_up",
# This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys,
# adding prefixes and adding alpha values
# Check `convert_state_dict_to_kohya` for more
}
PEFT_STATE_DICT_MAPPINGS = {
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
......@@ -110,6 +123,8 @@ DIFFUSERS_STATE_DICT_MAPPINGS = {
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}
KEYS_TO_ALWAYS_REPLACE = {
".processor.": ".",
}
......@@ -228,3 +243,82 @@ def convert_unet_state_dict_to_peft(state_dict):
"""
mapping = UNET_TO_DIFFUSERS
return convert_state_dict(state_dict, mapping)
def convert_all_state_dict_to_peft(state_dict):
r"""
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer`
for a valid `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
"""
try:
peft_dict = convert_state_dict_to_peft(state_dict)
except Exception as e:
if str(e) == "Could not automatically infer state dict type":
peft_dict = convert_unet_state_dict_to_peft(state_dict)
else:
raise
if not any("lora_A" in key or "lora_B" in key for key in peft_dict.keys()):
raise ValueError("Your LoRA was not converted to PEFT")
return peft_dict
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
r"""
Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
The method only supports the conversion from PEFT to Kohya for now.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
original_type (`StateDictType`, *optional*):
The original type of the state dict, if not provided, the method will try to infer it automatically.
kwargs (`dict`, *args*):
Additional arguments to pass to the method.
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
`get_peft_model_state_dict` method:
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
but we add it here in case we don't want to rely on that method.
"""
try:
import torch
except ImportError:
logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
raise
peft_adapter_name = kwargs.pop("adapter_name", None)
if peft_adapter_name is not None:
peft_adapter_name = "." + peft_adapter_name
else:
peft_adapter_name = ""
if original_type is None:
if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
original_type = StateDictType.PEFT
if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
raise ValueError(f"Original type {original_type} is not supported")
# Use the convert_state_dict function with the appropriate mapping
kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
kohya_ss_state_dict = {}
# Additional logic for replacing header, alpha parameters `.` with `_` in all keys
for kohya_key, weight in kohya_ss_partial_state_dict.items():
if "text_encoder_2." in kohya_key:
kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
elif "text_encoder." in kohya_key:
kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
elif "unet" in kohya_key:
kohya_key = kohya_key.replace("unet", "lora_unet")
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
kohya_ss_state_dict[kohya_key] = weight
if "lora_down" in kohya_key:
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
return kohya_ss_state_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment