"tests/cpp/test_spmat_coo.cc" did not exist on "41349dcef5c62e59e774b6672af0deca517414bc"
Unverified Commit 288632ad authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Training utils] add kohya conversion dict. (#7435)

* add kohya conversion dict.

* update readme

* typo

* add filename
parent 5ce79cbd
...@@ -259,9 +259,9 @@ The authors found that by using DoRA, both the learning capacity and training st ...@@ -259,9 +259,9 @@ The authors found that by using DoRA, both the learning capacity and training st
> This is also aligned with some of the quantitative analysis shown in the paper. > This is also aligned with some of the quantitative analysis shown in the paper.
**Usage** **Usage**
1. To use DoRA you need to install `peft` from main: 1. To use DoRA you need to upgrade the installation of `peft`:
```bash ```bash
pip install git+https://github.com/huggingface/peft.git pip install -U peft
``` ```
2. Enable DoRA training by adding this flag 2. Enable DoRA training by adding this flag
```bash ```bash
...@@ -269,3 +269,7 @@ pip install git+https://github.com/huggingface/peft.git ...@@ -269,3 +269,7 @@ pip install git+https://github.com/huggingface/peft.git
``` ```
**Inference** **Inference**
The inference is the same as if you train a regular LoRA 🤗 The inference is the same as if you train a regular LoRA 🤗
## Format compatibility
You can pass `--output_kohya_format` to additionally generate a state dictionary which should be compatible with other platforms and tools such as Automatic 1111, Comfy, Kohya, etc. The `output_dir` will contain a file named "pytorch_lora_weights_kohya.safetensors".
\ No newline at end of file
...@@ -41,6 +41,7 @@ from peft import LoraConfig, set_peft_model_state_dict ...@@ -41,6 +41,7 @@ from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict from peft.utils import get_peft_model_state_dict
from PIL import Image from PIL import Image
from PIL.ImageOps import exif_transpose from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import crop from torchvision.transforms.functional import crop
...@@ -62,7 +63,9 @@ from diffusers.optimization import get_scheduler ...@@ -62,7 +63,9 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
from diffusers.utils import ( from diffusers.utils import (
check_min_version, check_min_version,
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers, convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_unet_state_dict_to_peft, convert_unet_state_dict_to_peft,
is_wandb_available, is_wandb_available,
) )
...@@ -396,6 +399,11 @@ def parse_args(input_args=None): ...@@ -396,6 +399,11 @@ def parse_args(input_args=None):
default="lora-dreambooth-model", default="lora-dreambooth-model",
help="The output directory where the model predictions and checkpoints will be written.", help="The output directory where the model predictions and checkpoints will be written.",
) )
parser.add_argument(
"--output_kohya_format",
action="store_true",
help="Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument( parser.add_argument(
"--resolution", "--resolution",
...@@ -1890,6 +1898,11 @@ def main(args): ...@@ -1890,6 +1898,11 @@ def main(args):
text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers,
) )
if args.output_kohya_format:
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors")
# Final inference # Final inference
# Load previous pipeline # Load previous pipeline
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment