Unverified Commit 1622265e authored by apolinário, committed by GitHub

Add WebUI format support to Advanced Training Script (#6403)



* Add WebUI format support to Advanced Training Script

* style

---------
Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
parent 0b63ad5a
@@ -20,6 +20,7 @@ import itertools
 import logging
 import math
 import os
+import re
 import shutil
 import warnings
 from pathlib import Path
@@ -41,7 +42,7 @@ from peft import LoraConfig
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -58,7 +59,13 @@ from diffusers import (
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_all_state_dict_to_peft,
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_kohya,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
@@ -93,10 +100,17 @@ def save_model_card(
         img_str += f"""
         - text: '{instance_prompt}'
         """
+    embeddings_filename = f"{repo_folder}_emb"
+    instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
+    ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
+    if instance_prompt_webui != embeddings_filename:
+        instance_prompt_sentence = f"For example, `{instance_prompt_webui}`"
+    else:
+        instance_prompt_sentence = ""
     trigger_str = f"You should use {instance_prompt} to trigger the image generation."
     diffusers_imports_pivotal = ""
     diffusers_example_pivotal = ""
+    webui_example_pivotal = ""
     if train_text_encoder_ti:
         trigger_str = (
             "To trigger image generation of trained concept(or concepts) replace each concept identifier "
@@ -105,11 +119,16 @@ def save_model_card(
         diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 """
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model") diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
 state_dict = load_file(embedding_path)
-pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
-pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
 """
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
- Place it on it on your `embeddings` folder
- Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence}
(you need both the LoRA and the embeddings as they were trained together for this LoRA)
"""
         if token_abstraction_dict:
             for key, value in token_abstraction_dict.items():
                 tokens = "".join(value)
@@ -141,9 +160,14 @@ license: openrail++
 ### These are {repo_id} LoRA adaption weights for {base_model}.
-## Trigger words
-{trigger_str}
+## Download model
+### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
+- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.
+    - Place it in your `models/Lora` folder.
+    - On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
+{webui_example_pivotal}
 ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
@@ -159,16 +183,12 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}'
 For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
-## Download model
-### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
-- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder.
-- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder.
-All [Files & versions](/{repo_id}/tree/main).
+## Trigger words
+{trigger_str}
 ## Details
+All [Files & versions](/{repo_id}/tree/main).
 The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
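Note: the card generated from this template walks users through loading both artifacts with diffusers. A minimal sketch of that flow, assuming a hypothetical repo `user/my-lora` trained with pivotal tuning on tokens `<s0>` and `<s1>`:

import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

repo_id = "user/my-lora"  # hypothetical repo id; substitute your own

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(repo_id)

# load the pivotal-tuning embeddings into both SDXL text encoders
embedding_path = hf_hub_download(repo_id=repo_id, filename="my-lora_emb.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(
    state_dict["clip_l"], token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer,
)
pipeline.load_textual_inversion(
    state_dict["clip_g"], token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2,
)

image = pipeline("photo of <s0><s1> dog").images[0]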
@@ -2035,8 +2055,15 @@ def main(args):
         if args.train_text_encoder_ti:
             embedding_handler.save_embeddings(
-                f"{args.output_dir}/embeddings.safetensors",
+                f"{args.output_dir}/{args.output_dir}_emb.safetensors",
             )
+        # Convert to WebUI format
+        lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
+        peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
+        kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+        save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
         save_model_card(
             model_id if not args.push_to_hub else repo_id,
             images=images,
...
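Note: the same Kohya/WebUI conversion can also be run after the fact on a LoRA this script has already produced. A minimal standalone sketch, assuming the `pytorch_lora_weights.safetensors` file sits in a hypothetical training output folder:

import os

from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
from safetensors.torch import load_file, save_file

output_dir = "my-lora"  # hypothetical training output folder

# diffusers-format LoRA -> PEFT naming -> Kohya naming used by the WebUIs
lora_state_dict = load_file(os.path.join(output_dir, "pytorch_lora_weights.safetensors"))
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)

# single-file LoRA that AUTOMATIC1111 / ComfyUI can load from models/Lora
save_file(kohya_state_dict, os.path.join(output_dir, f"{output_dir}.safetensors"))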