Unverified Commit 01ee0978 authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[advanced dreambooth lora sdxl] add DoRA training feature (#7072)



* add is_dora arg

* style

* add dora training feature to sd 1.5 script

* added notes about DoRA training

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 56b68459
......@@ -80,8 +80,7 @@ To do so, just specify `--train_text_encoder_ti` while launching training (for r
Please keep the following points in mind:
* SDXL has two text encoders. So, we fine-tune both using LoRA.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memoםהקרry.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
### 3D icon example
......@@ -234,6 +233,32 @@ In ComfyUI we will load a LoRA and a textual embedding at the same time.
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
### DoRA training
The advanced script now supports DoRA training too!
> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction** and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.
> [!NOTE]
> 💡DoRA training is still _experimental_
> and is likely to require different hyperparameter values to perform best compared to a LoRA.
> Specifically, we've noticed 2 differences to take into account your training:
> 1. **LoRA seem to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may be working well for a DoRA)
> 2. **DoRA quality superior to LoRA especially in lower ranks** the difference in quality of DoRA of rank 8 and LoRA of rank 8 appears to be more significant than when training ranks of 32 or 64 for example.
> This is also aligned with some of the quantitative analysis shown in the paper.
**Usage**
1. To use DoRA you need to install `peft` from main:
```bash
pip install git+https://github.com/huggingface/peft.git
```
2. Enable DoRA training by adding this flag
```bash
--use_dora
```
**Inference**
The inference is the same as if you train a regular LoRA 🤗
### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
......
......@@ -651,6 +651,16 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--use_dora",
type=bool,
action="store_true",
default=False,
help=(
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
......@@ -1219,6 +1229,7 @@ def main(args):
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
use_dora=args.use_dora,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
......@@ -1230,6 +1241,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
use_dora=args.use_dora,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
......
......@@ -661,6 +661,16 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--use_dora",
type=bool,
action="store_true",
default=False,
help=(
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
......@@ -1323,6 +1333,7 @@ def main(args):
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
use_dora=args.use_dora,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
......@@ -1334,6 +1345,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
use_dora=args.use_dora,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment