Unverified Commit 26a7851e authored by Linoy Tsaban, committed by GitHub

Add B-Lora training option to the advanced dreambooth lora script (#7741)



* add blora

* add blora

* add blora

* add blora

* little changes

* little changes

* remove redundancies

* fixes

* add B LoRA to readme

* style

* inference

* defaults + path to loras + generation

* minor changes

* style

* minor changes

* minor changes

* blora arg

* added --lora_unet_blocks

* style

* Update examples/advanced_diffusion_training/README.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* add commit hash to B-LoRA repo cloning

* change inference, remove cloning

* change inference, remove cloning
add section about configurable unet blocks

* change inference, remove cloning
add section about configurable unet blocks

* Apply suggestions from code review

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 3fd31eef
......@@ -234,7 +234,7 @@ In ComfyUI we will load a LoRA and a textual embedding at the same time.
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
### DoRA training
The advanced script now supports DoRA training too!
The advanced script supports DoRA training too!
> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction** and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.
......@@ -304,6 +304,147 @@ accelerate launch train_dreambooth_lora_sdxl_advanced.py \
> [!CAUTION]
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".
### B-LoRA training
The advanced script now supports B-LoRA training too!
> Proposed in [Implicit Style-Content Separation using B-LoRA](https://arxiv.org/abs/2403.14572),
B-LoRA is a method that leverages LoRA to implicitly separate the style and content components of a **single** image.
It was shown that learning the LoRA weights of two specific blocks (referred to as B-LoRAs)
achieves style-content separation that cannot be achieved by training each B-LoRA independently.
Once trained, the two B-LoRAs can be used as independent components, allowing various image stylization tasks.
**Usage**
Enable B-LoRA training by adding this flag:
```bash
--use_blora
```
You can train a B-LoRA with as few as 1 image and 1000 steps. Try this default configuration as a start:
```bash
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
--use_blora \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
--dataset_name="linoyts/B-LoRA_teddy_bear" \
--output_dir="B-LoRA_teddy_bear" \
--instance_prompt="a [v18]" \
--resolution=1024 \
--rank=64 \
--train_batch_size=1 \
--learning_rate=5e-5 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--gradient_checkpointing \
--mixed_precision="fp16"
```
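Once training is done, the LoRA layers are saved to `--output_dir` (with the standard diffusers saving path this is a `pytorch_lora_weights.safetensors` file), and can be filtered and loaded just like the Hub checkpoints in the inference example below.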
**Inference**
Inference is a bit different:
1. we need to load *specific* unet layers (as opposed to a regular LoRA/DoRA)
2. the trained layers we load change based on our objective (e.g. style/content)
```python
import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

# taken & modified from the B-LoRA repo - https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py
def is_belong_to_blocks(key, blocks):
    try:
        for g in blocks:
            if g in key:
                return True
        return False
    except Exception as e:
        raise type(e)(f"failed to is_belong_to_block, due to: {e}")


def filter_lora_unet_blocks(lora_path, alpha, target_blocks):
    # keep only the LoRA weights that belong to the targeted unet blocks, scaled by alpha
    state_dict, _ = pipeline.lora_state_dict(lora_path)
    filtered_state_dict = {k: v * alpha for k, v in state_dict.items() if is_belong_to_blocks(k, target_blocks)}
    return filtered_state_dict


# pick a B-LoRA for content and one for style (you can also set one of them to None)
content_B_lora_path = "lora-library/B-LoRA-teddybear"
style_B_lora_path = "lora-library/B-LoRA-pen_sketch"

content_B_LoRA = filter_lora_unet_blocks(content_B_lora_path, alpha=1, target_blocks=["unet.up_blocks.0.attentions.0"])
style_B_LoRA = filter_lora_unet_blocks(style_B_lora_path, alpha=1.1, target_blocks=["unet.up_blocks.0.attentions.1"])
combined_lora = {**content_B_LoRA, **style_B_LoRA}

# load both LoRAs into the unet
pipeline.load_lora_into_unet(combined_lora, None, pipeline.unet)

# generate
prompt = "a [v18] in [v30] style"
images = pipeline(prompt, num_images_per_prompt=4).images
```
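As the comment above notes, you can also load just one of the two components. Here's a minimal style-only sketch, reusing the `pipeline` and `filter_lora_unet_blocks` defined above (the prompt is just an example):

```python
# style-only stylization: skip the content B-LoRA and load only the style blocks
style_only = filter_lora_unet_blocks(
    "lora-library/B-LoRA-pen_sketch", alpha=1.1, target_blocks=["unet.up_blocks.0.attentions.1"]
)
pipeline.load_lora_into_unet(style_only, None, pipeline.unet)

# the content now comes from the prompt rather than from a content B-LoRA
images = pipeline("a house in [v30] style", num_images_per_prompt=4).images
```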
### LoRA training of Targeted U-net Blocks
The advanced script now supports a custom choice of which U-net blocks to train during DreamBooth LoRA tuning.
> [!NOTE]
> This feature is still experimental
> Recently, works like B-LoRA showed the potential advantages of learning the LoRA weights of specific U-net blocks, not only in terms of speed & memory,
> but also in reducing the amount of data needed, improving style manipulation, and overcoming overfitting issues.
> In light of this, we're introducing a new feature to the advanced script to allow for configurable U-net learned blocks.
**Usage**
Configure the learned U-net blocks by adding the `--lora_unet_blocks` flag, with a comma-separated string specifying the targeted blocks, e.g.:
```bash
--lora_unet_blocks="unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1"
```
> [!NOTE]
> If you specify both `--use_blora` and `--lora_unet_blocks`, the values given in `--lora_unet_blocks` will be ignored.
> When enabling `--use_blora`, the targeted U-net blocks are automatically set to `unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1`, as discussed in the paper.
> If you wish to experiment with different blocks, specify `--lora_unet_blocks` only (a sketch for discovering valid block names follows below).
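
Valid values follow the U-net's module naming. If you're unsure which blocks exist, here's a minimal sketch (assuming the SDXL base checkpoint) that derives the candidate block names from `unet.attn_processors`:

```python
import torch
from diffusers import UNet2DConditionModel

# list the attention blocks of the SDXL U-net that can be passed to --lora_unet_blocks
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

blocks = set()
for name in unet.attn_processors.keys():
    # e.g. "up_blocks.0.attentions.1.transformer_blocks.3.attn2.processor"
    blocks.add("unet." + name.split(".transformer_blocks")[0])

print("\n".join(sorted(blocks)))
```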
**Inference**
Inference is the same as for B-LoRA, except that the targeted blocks should be modified to match your training configuration.
```python
import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

# taken & modified from the B-LoRA repo - https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py
def is_belong_to_blocks(key, blocks):
    try:
        for g in blocks:
            if g in key:
                return True
        return False
    except Exception as e:
        raise type(e)(f"failed to is_belong_to_block, due to: {e}")


def filter_lora_unet_blocks(lora_path, alpha, target_blocks):
    # keep only the LoRA weights that belong to the targeted unet blocks, scaled by alpha
    state_dict, _ = pipeline.lora_state_dict(lora_path)
    filtered_state_dict = {k: v * alpha for k, v in state_dict.items() if is_belong_to_blocks(k, target_blocks)}
    return filtered_state_dict


lora_path = "lora-library/B-LoRA-pen_sketch"
# `unet.up_blocks.0.attentions.1` is the style block of this B-LoRA - change `target_blocks` to match the blocks you trained
state_dict = filter_lora_unet_blocks(lora_path, alpha=1, target_blocks=["unet.up_blocks.0.attentions.1"])

# load the trained LoRA layers into the unet
pipeline.load_lora_into_unet(state_dict, None, pipeline.unet)

# generate
prompt = "a dog in [v30] style"
images = pipeline(prompt, num_images_per_prompt=4).images
```
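To experiment with a different checkpoint or block selection in the same session, first remove the currently loaded layers. A minimal sketch, reusing the objects from the snippet above:

```python
# drop the currently loaded LoRA layers before loading new ones
pipeline.unload_lora_weights()

# e.g. switch to the content blocks of a content B-LoRA
state_dict = filter_lora_unet_blocks(
    "lora-library/B-LoRA-teddybear", alpha=1, target_blocks=["unet.up_blocks.0.attentions.0"]
)
pipeline.load_lora_into_unet(state_dict, None, pipeline.unet)
images = pipeline("a [v18] on the beach", num_images_per_prompt=4).images
```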
### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices).
......
......@@ -15,7 +15,6 @@
import argparse
import gc
import hashlib
import itertools
import json
import logging
......@@ -40,6 +39,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
......@@ -696,6 +696,23 @@ def parse_args(input_args=None):
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
parser.add_argument(
"--lora_unet_blocks",
type=str,
default=None,
help=(
"the U-net blocks to tune during training. please specify them in a comma separated string, e.g. `unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1` etc."
"NOTE: By default (if not specified) - regular LoRA training is performed. "
"if --use_blora is enabled, this arg will be ignored, since in B-LoRA training, targeted U-net blocks are `unet.up_blocks.0.attentions.0` and `unet.up_blocks.0.attentions.1`"
),
)
parser.add_argument(
"--use_blora",
action="store_true",
help=(
"Whether to train a B-LoRA as proposed in- Implicit Style-Content Separation using B-LoRA https://arxiv.org/abs/2403.14572. "
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
......@@ -720,6 +737,11 @@ def parse_args(input_args=None):
"For full LoRA text encoder training check --train_text_encoder, for textual "
"inversion training check `--train_text_encoder_ti`"
)
if args.use_blora and args.lora_unet_blocks:
warnings.warn(
"You specified both `--use_blora` and `--lora_unet_blocks`, for B-LoRA training, target unet blocks are: `unet.up_blocks.0.attentions.0` and `unet.up_blocks.0.attentions.1`. "
"If you wish to target different U-net blocks, don't enable `--use_blora`"
)
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
......@@ -740,6 +762,40 @@ def parse_args(input_args=None):
return args
# Taken (and slightly modified) from B-LoRA repo https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py
def is_belong_to_blocks(key, blocks):
try:
for g in blocks:
if g in key:
return True
return False
except Exception as e:
raise type(e)(f"failed to is_belong_to_block, due to: {e}")
def get_unet_lora_target_modules(unet, use_blora, target_blocks=None):
if use_blora:
content_b_lora_blocks = "unet.up_blocks.0.attentions.0"
style_b_lora_blocks = "unet.up_blocks.0.attentions.1"
target_blocks = [content_b_lora_blocks, style_b_lora_blocks]
try:
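        # drop the leading "unet." prefix so the block names match the keys of unet.attn_processors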
blocks = [(".").join(blk.split(".")[1:]) for blk in target_blocks]
attns = [
attn_processor_name.rsplit(".", 1)[0]
for attn_processor_name, _ in unet.attn_processors.items()
if is_belong_to_blocks(attn_processor_name, blocks)
]
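        # expand each matched attention module into its four LoRA-targetable projection matrices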
target_modules = [f"{attn}.{mat}" for mat in ["to_k", "to_q", "to_v", "to_out.0"] for attn in attns]
return target_modules
except Exception as e:
raise type(e)(
f"failed to get_target_modules, due to: {e}. "
f"Please check the modules specified in --lora_unet_blocks are correct"
)
# Taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
class TokenEmbeddingsHandler:
def __init__(self, text_encoders, tokenizers):
......@@ -946,16 +1002,20 @@ class DreamBoothDataset(Dataset):
transforms.Normalize([0.5], [0.5]),
]
)
        # if using B-LoRA with a single image, do not apply the augmentation transformations
single_image = len(self.instance_images) < 2
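        # skipping EXIF transpose/random flip and forcing a center crop keeps preprocessing deterministic for the single reference image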
for image in self.instance_images:
if not single_image:
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
self.original_sizes.append((image.height, image.width))
image = train_resize(image)
if args.random_flip and random.random() < 0.5:
if not single_image and args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop:
if args.center_crop or single_image:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
image = train_crop(image)
......@@ -1216,7 +1276,7 @@ def main(args):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
......@@ -1374,12 +1434,24 @@ def main(args):
text_encoder_two.gradient_checkpointing_enable()
# now we will add new LoRA weights to the attention layers
if args.use_blora:
# if using B-LoRA, the targeted blocks to train are automatically set
target_modules = get_unet_lora_target_modules(unet, use_blora=True)
elif args.lora_unet_blocks:
# if training specific unet blocks not in the B-LoRA scheme
target_blocks_list = "".join(args.lora_unet_blocks.split()).split(",")
logger.info(f"list of unet blocks to train: {target_blocks_list}")
target_modules = get_unet_lora_target_modules(unet, use_blora=False, target_blocks=target_blocks_list)
else:
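        # regular LoRA: target the attention projections across the entire unet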
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
target_modules=target_modules,
)
unet.add_adapter(unet_lora_config)
......@@ -1388,8 +1460,8 @@ def main(args):
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
......@@ -1505,6 +1577,7 @@ def main(args):
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
......@@ -1525,6 +1598,8 @@ def main(args):
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models, dtype=torch.float32)
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
......@@ -1780,7 +1855,12 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
tracker_name = (
"dreambooth-lora-sd-xl"
if "playground" not in args.pretrained_model_name_or_path
else "dreambooth-lora-playground"
)
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
......@@ -1833,7 +1913,6 @@ def main(args):
)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# TODO: revisit other sampling algorithms
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
timesteps = timesteps.to(accelerator.device)
......@@ -1852,6 +1931,7 @@ def main(args):
# flag used for textual inversion
pivoted = False
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
# if performing any kind of optimization of text_encoder params
if args.train_text_encoder or args.train_text_encoder_ti:
if epoch == num_train_epochs_text_encoder:
......@@ -1869,7 +1949,6 @@ def main(args):
text_encoder_one.text_model.embeddings.requires_grad_(True)
text_encoder_two.text_model.embeddings.requires_grad_(True)
unet.train()
for step, batch in enumerate(train_dataloader):
if pivoted:
# stopping optimization of text_encoder params
......@@ -1970,7 +2049,8 @@ def main(args):
timesteps,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
).sample
return_dict=False,
)[0]
else:
unet_added_conditions = {"time_ids": add_time_ids}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
......@@ -1988,7 +2068,8 @@ def main(args):
timesteps,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
).sample
return_dict=False,
)[0]
weighting = None
if args.do_edm_style_training:
......