Unverified commit b5814c55 authored by Linoy Tsaban, committed by GitHub

add DoRA training feature to sdxl dreambooth lora script (#7235)

* dora in canonical script

* add mention of DoRA to readme
parent 99405736
@@ -243,3 +243,29 @@ accelerate launch train_dreambooth_lora_sdxl.py \
> [!CAUTION]
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".

### DoRA training
The script now supports DoRA training too!
> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction**, and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.
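
Roughly speaking (this is a sketch of the idea; see the paper for the exact formulation), the fine-tuned weight is re-parameterized as

```math
W' = m \, \frac{W_0 + B A}{\lVert W_0 + B A \rVert_c}
```

where $W_0$ is the pre-trained weight, $m$ is a trainable magnitude vector, $B A$ is the usual low-rank LoRA update applied to the direction, and $\lVert \cdot \rVert_c$ denotes a column-wise norm.
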
> [!NOTE]
> 💡DoRA training is still _experimental_
> and is likely to require different hyperparameter values to perform best compared to a LoRA.
> Specifically, we've noticed 2 differences to take into account in your training:
> 1. **LoRAs seem to converge faster than DoRAs** (so a set of parameters that may lead to overfitting when training a LoRA may work well for a DoRA).
> 2. **DoRA quality is superior to LoRA's, especially at lower ranks**: the difference in quality between a DoRA of rank 8 and a LoRA of rank 8 appears to be more significant than when training at ranks of 32 or 64, for example.
> This is also aligned with some of the quantitative analysis shown in the paper.
**Usage**
1. To use DoRA you need to install `peft` from main:
```bash
pip install git+https://github.com/huggingface/peft.git
```
2. Enable DoRA training by adding this flag to your launch command (a short sketch of what it does under the hood follows this list):
```bash
--use_dora
```
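
Under the hood, `--use_dora` simply forwards `use_dora=True` to `peft`'s `LoraConfig`, as the script changes further down in this commit show. A minimal sketch of the resulting UNet adapter configuration (rank 4 is just the script's default `--rank` value):

```python
from peft import LoraConfig

# DoRA reuses the regular LoRA configuration; the only change is `use_dora=True`.
unet_lora_config = LoraConfig(
    r=4,                      # --rank (the script's default is 4)
    lora_alpha=4,             # the script sets lora_alpha equal to the rank
    use_dora=True,            # what --use_dora toggles (requires peft installed from main)
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # SDXL UNet attention projections
)
# the training script attaches this to the UNet with `unet.add_adapter(unet_lora_config)`;
# with --train_text_encoder, the text encoders get the same config but with
# target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
```
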
**Inference**
Inference is the same as if you had trained a regular LoRA 🤗
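
For example, a minimal sketch (assuming you trained against the default SDXL base checkpoint; the output directory and prompt are placeholders to adapt to your own run):

```python
import torch
from diffusers import AutoPipelineForText2Image

# load the base model you trained against
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# a DoRA checkpoint is loaded exactly like a regular LoRA checkpoint
pipeline.load_lora_weights("path/to/your/output_dir")  # placeholder path

image = pipeline("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dora_sample.png")
```
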
\ No newline at end of file
@@ -647,6 +647,15 @@ def parse_args(input_args=None):
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--use_dora",
        action="store_true",
        default=False,
        help=(
            "Whether to train a DoRA as proposed in 'DoRA: Weight-Decomposed Low-Rank Adaptation' (https://arxiv.org/abs/2402.09353). "
            "Note: to use DoRA you need to install `peft` from main: `pip install git+https://github.com/huggingface/peft.git`"
        ),
    )
    if input_args is not None:
        args = parser.parse_args(input_args)
@@ -1147,6 +1156,7 @@ def main(args):
    # now we will add new LoRA weights to the attention layers
    unet_lora_config = LoraConfig(
        r=args.rank,
        use_dora=args.use_dora,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -1158,6 +1168,7 @@ def main(args):
    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
            use_dora=args.use_dora,
            lora_alpha=args.rank,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
...