Unverified Commit 65329aed authored by Linoy Tsaban, committed by GitHub

[advanced dreambooth lora sdxl script] new features + bug fixes (#6691)



* add noise_offset param

* micro conditioning - wip

* image processing adjusted and moved to support micro conditioning

* change time ids to be computed inside train loop

* change time ids to be computed inside train loop

* change time ids to be computed inside train loop

* time ids shape fix

* move token replacement of validation prompt to the same section of instance prompt and class prompt

* add offset noise to sd15 advanced script

* fix token loading during validation

* fix token loading during validation in sdxl script

* a little clean

* style

* a little clean

* style

* sdxl script - a little clean + minor path fix

sd 1.5 script - change default resolution value

* sd 1.5 script - minor path fix

* fix missing comma in code example in model card

* clean up commented lines

* style

* remove time ids computed outside training loop - no longer used now that we utilize micro-conditioning, as all time ids are now computed inside the training loop

* style

* [WIP] - added draft readme, building off of examples/dreambooth/README.md

* readme

* readme

* readme

* readme

* readme

* readme

* readme

* readme

* removed --crops_coords_top_left from CLI args

* style

* fix missing shape bug due to missing RGB if statement

* add blog mention at the start of the readme as well

* Update examples/advanced_diffusion_training/README.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* change note to render nicely as well

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 02338c93
# Advanced diffusion training examples
## Train Dreambooth LoRA with Stable Diffusion XL
> [!TIP]
> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
LoRA (Low-Rank Adaptation of Large Language Models) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and training **only** those newly added weights. This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow control over the extent to which the model is adapted towards new training images via a `scale` parameter.
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py` with
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
> [!NOTE]
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
📚 Read more about the advanced features and best practices in this community-derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd into the `examples/advanced_diffusion_training` folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell e.g. a notebook
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, enabling torch compile mode can bring dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
### Pivotal Tuning
**Training with text encoder(s)**
Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. On top of the regular text encoder optimization
available in `train_dreambooth_lora_sdxl_advanced.py`, the advanced script also supports **pivotal tuning**.
[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
We then optimize the newly-inserted token embeddings to represent the new concept.
To do so, just specify `--train_text_encoder_ti` when launching training (for regular text encoder optimization, use `--train_text_encoder`); a conceptual sketch of what pivotal tuning does follows the list below.
Please keep the following points in mind:
* SDXL has two text encoders. So, we fine-tune both using LoRA.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
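To make the mechanics concrete, here's a minimal, hypothetical sketch of the pivotal tuning idea - not the script's actual implementation (the script handles this for both SDXL text encoders via its `TokenEmbeddingsHandler`). New tokens are added to the tokenizer, the embedding matrix is resized, and only the newly added rows are allowed to change:
```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# 1. insert brand-new tokens and grow the embedding matrix to match
new_tokens = ["<s0>", "<s1>"]
tokenizer.add_tokens(new_tokens)
text_encoder.resize_token_embeddings(len(tokenizer))
new_ids = tokenizer.convert_tokens_to_ids(new_tokens)

# 2. let gradients reach the embedding matrix
embedding = text_encoder.get_input_embeddings()
embedding.weight.requires_grad_(True)
frozen_rows = embedding.weight.data.clone()

# 3. after every optimizer step, restore all pre-existing rows so that
#    only the newly inserted embeddings are effectively trained
def keep_old_embeddings_frozen():
    mask = torch.ones(embedding.weight.shape[0], dtype=torch.bool)
    mask[new_ids] = False
    embedding.weight.data[mask] = frozen_rows[mask]
```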
### 3D icon example
Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download

local_dir = "./3d_icon"
snapshot_download(
    "LinoyTsaban/3d_icon",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```
Let's review some of the advanced features we're going to be using for this example:
- **custom captions**:
To use custom captioning, first ensure that you have the datasets library installed; if not, you can install it with
```bash
pip install datasets
```
Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
```bash
--dataset_name=./3d_icon
--caption_column=prompt
```
You can also load a dataset straight from the Hub by specifying its name in `dataset_name`.
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset; a quick way to sanity-check the caption column is shown right after this list.
- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
- **pivotal tuning**
- **min SNR gamma**
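Before launching, it can help to sanity-check the custom captions with a couple of lines of `datasets` code - a quick sketch (the `prompt` column name is specific to this example dataset):
```python
from datasets import load_dataset

# load the example dataset straight from the Hub and peek at its captions
dataset = load_dataset("LinoyTsaban/3d_icon", split="train")
print(dataset.column_names)  # expect an image column and a "prompt" column
print(dataset[0]["prompt"])  # one of the captions the script will train on
```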
**Now, we can launch training:**
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-SDXL-LoRA"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--pretrained_vae_model_name_or_path=$VAE_PATH \
--dataset_name=$DATASET_NAME \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir=$OUTPUT_DIR \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=3 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--train_text_encoder_ti_frac=0.5\
--snr_gamma=5.0 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
Our experiments were conducted on a single 40GB A100 GPU.
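A side note on `--snr_gamma`: it enables [Min-SNR weighting](https://arxiv.org/abs/2303.09556), which rescales the per-timestep diffusion loss so that low-noise timesteps don't dominate training. Conceptually, the weight is `min(SNR, gamma) / SNR` - a sketch of the weighting for epsilon-prediction, not the script's exact code:
```python
import torch

def min_snr_loss_weights(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    # clamp the signal-to-noise ratio at gamma, then normalize by SNR:
    # high-SNR (low-noise) timesteps get down-weighted, the rest keep weight ~1
    return torch.clamp(snr, max=gamma) / snr
```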
### Inference
Once training is done, we can perform inference like so:
1. Start by loading the UNet LoRA weights:
```python
import torch
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
username = "linoyts"
repo_id = f"{username}/3d-icon-SDXL-LoRA"
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
```
2. Now we load the pivotal tuning embeddings:
```python
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```
3. Let's generate images:
```python
instance_token = "<s0><s1>"
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image.save("llama.png")
```
### Comfy UI / AUTOMATIC1111 Inference
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
**AUTOMATIC1111 / SD.Next** \
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
**ComfyUI** \
In ComfyUI we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
### Specifying a better VAE
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
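For example, at inference time you can pass the fixed VAE explicitly - a small sketch using the checkpoint linked above:
```python
import torch
from diffusers import AutoencoderKL, DiffusionPipeline

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
```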
### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices).
## Running on Colab Notebook
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb) to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free colab, using some of the advanced features (excluding pivotal tuning).
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft==0.7.0
@@ -119,10 +119,9 @@ def save_model_card(
     diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 """
-    diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
+    diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
 state_dict = load_file(embedding_path)
 pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
-pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
 """
     webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
     - Place it on it on your `embeddings` folder
@@ -389,7 +388,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--resolution",
         type=int,
-        default=1024,
+        default=512,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -645,6 +644,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
     parser.add_argument(
         "--rank",
         type=int,
@@ -745,10 +745,11 @@ class TokenEmbeddingsHandler:
                 idx += 1

+    # copied from train_dreambooth_lora_sdxl_advanced.py
     def save_embeddings(self, file_path: str):
         assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
         tensors = {}
-        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
+        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
         idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
@@ -1634,6 +1635,11 @@ def main(args):
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
+                if args.noise_offset:
+                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                    noise += args.noise_offset * torch.randn(
+                        (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+                    )
                 bsz = model_input.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
@@ -1788,6 +1794,7 @@ def main(args):
             pipeline = StableDiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 vae=vae,
+                tokenizer=tokenizer_one,
                 text_encoder=accelerator.unwrap_model(text_encoder_one),
                 unet=accelerator.unwrap_model(unet),
                 revision=args.revision,
@@ -1860,6 +1867,11 @@ def main(args):
             unet_lora_layers=unet_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
         )
+        if args.train_text_encoder_ti:
+            embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
+            embedding_handler.save_embeddings(embeddings_path)
+
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
             # Final inference
@@ -1895,6 +1907,18 @@ def main(args):
             # load attention processors
             pipeline.load_lora_weights(args.output_dir)

+            # load new tokens
+            if args.train_text_encoder_ti:
+                state_dict = load_file(embeddings_path)
+                all_new_tokens = []
+                for key, value in token_abstraction_dict.items():
+                    all_new_tokens.extend(value)
+                pipeline.load_textual_inversion(
+                    state_dict["clip_l"],
+                    token=all_new_tokens,
+                    text_encoder=pipeline.text_encoder,
+                    tokenizer=pipeline.tokenizer,
+                )
             # run inference
             pipeline = pipeline.to(accelerator.device)
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
@@ -1917,11 +1941,6 @@ def main(args):
                 }
             )

-        if args.train_text_encoder_ti:
-            embedding_handler.save_embeddings(
-                f"{args.output_dir}/{args.output_dir}_emb.safetensors",
-            )
-
     # Conver to WebUI format
     lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
     peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
...
@@ -20,6 +20,7 @@ import itertools
 import logging
 import math
 import os
+import random
 import re
 import shutil
 import warnings
@@ -45,6 +46,7 @@ from PIL.ImageOps import exif_transpose
 from safetensors.torch import load_file, save_file
 from torch.utils.data import Dataset
 from torchvision import transforms
+from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
@@ -121,7 +123,7 @@ def save_model_card(
     diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 """
-    diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
+    diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
 state_dict = load_file(embedding_path)
 pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
 pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
@@ -397,18 +399,6 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--crops_coords_top_left_h",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
-    parser.add_argument(
-        "--crops_coords_top_left_w",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -418,6 +408,11 @@ def parse_args(input_args=None):
             " cropped. The images will be resized to the resolution first before cropping."
         ),
     )
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
     parser.add_argument(
         "--train_text_encoder",
         action="store_true",
@@ -659,6 +654,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
     parser.add_argument(
         "--rank",
         type=int,
@@ -901,6 +897,41 @@ class DreamBoothDataset(Dataset):
             self.instance_images = []
             for img in instance_images:
                 self.instance_images.extend(itertools.repeat(img, repeats))

+        # image processing to prepare for using SD-XL micro-conditioning
+        self.original_sizes = []
+        self.crop_top_lefts = []
+        self.pixel_values = []
+        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+        train_flip = transforms.RandomHorizontalFlip(p=1.0)
+        train_transforms = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        for image in self.instance_images:
+            image = exif_transpose(image)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+            self.original_sizes.append((image.height, image.width))
+            image = train_resize(image)
+            if args.random_flip and random.random() < 0.5:
+                # flip
+                image = train_flip(image)
+            if args.center_crop:
+                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                image = train_crop(image)
+            else:
+                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                image = crop(image, y1, x1, h, w)
+            crop_top_left = (y1, x1)
+            self.crop_top_lefts.append(crop_top_left)
+            image = train_transforms(image)
+            self.pixel_values.append(image)
+
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
@@ -930,12 +961,12 @@ class DreamBoothDataset(Dataset):
     def __getitem__(self, index):
         example = {}
-        instance_image = self.instance_images[index % self.num_instance_images]
-        instance_image = exif_transpose(instance_image)
-
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
-        example["instance_images"] = self.image_transforms(instance_image)
+        instance_image = self.pixel_values[index % self.num_instance_images]
+        original_size = self.original_sizes[index % self.num_instance_images]
+        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+        example["instance_images"] = instance_image
+        example["original_size"] = original_size
+        example["crop_top_left"] = crop_top_left

         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -966,6 +997,8 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = [example["instance_images"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]
+    original_sizes = [example["original_size"] for example in examples]
+    crop_top_lefts = [example["crop_top_left"] for example in examples]

     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
@@ -976,7 +1009,12 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

-    batch = {"pixel_values": pixel_values, "prompts": prompts}
+    batch = {
+        "pixel_values": pixel_values,
+        "prompts": prompts,
+        "original_sizes": original_sizes,
+        "crop_top_lefts": crop_top_lefts,
+    }
     return batch
@@ -1198,7 +1236,9 @@ def main(args):
             args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
             if args.with_prior_preservation:
                 args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
+            if args.validation_prompt:
+                args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
+                print("validation prompt:", args.validation_prompt)

         # initialize the new tokens for textual inversion
         embedding_handler = TokenEmbeddingsHandler(
             [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
@@ -1539,11 +1579,11 @@ def main(args):
     # pooled text embeddings
     # time ids
-    def compute_time_ids():
+    def compute_time_ids(crops_coords_top_left, original_size=None):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        original_size = (args.resolution, args.resolution)
+        if original_size is None:
+            original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
-        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
         add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
@@ -1560,9 +1600,6 @@ def main(args):
         pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds

-    # Handle instance prompt.
-    instance_time_ids = compute_time_ids()
-
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
@@ -1573,7 +1610,6 @@ def main(args):
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_time_ids = compute_time_ids()
         if freeze_text_encoder:
             class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                 args.class_prompt, text_encoders, tokenizers
@@ -1588,9 +1624,6 @@ def main(args):
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
-    add_time_ids = instance_time_ids
-    if args.with_prior_preservation:
-        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)

     # if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
     add_special_tokens = True if args.train_text_encoder_ti else False
@@ -1613,12 +1646,6 @@ def main(args):
             tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

-    if args.train_text_encoder_ti and args.validation_prompt:
-        # replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
-        for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
-            args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
-        print("validation prompt:", args.validation_prompt)
-
     if args.cache_latents:
         latents_cache = []
         for batch in tqdm(train_dataloader, desc="Caching latents"):
@@ -1778,6 +1805,12 @@ def main(args):
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
+                if args.noise_offset:
+                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                    noise += args.noise_offset * torch.randn(
+                        (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+                    )
+
                 bsz = model_input.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
@@ -1789,19 +1822,26 @@ def main(args):
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

+                # time ids
+                add_time_ids = torch.cat(
+                    [
+                        compute_time_ids(original_size=s, crops_coords_top_left=c)
+                        for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+                    ]
+                )
+
                 # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
                 if not train_dataset.custom_instance_prompts:
                     elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
                 else:
                     elems_to_repeat_text_embeds = 1
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz

                 # Predict the noise residual
                 if freeze_text_encoder:
                     unet_added_conditions = {
-                        "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+                        "time_ids": add_time_ids,
+                        # "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
                         "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
@@ -1812,7 +1852,7 @@ def main(args):
                         added_cond_kwargs=unet_added_conditions,
                     ).sample
                 else:
-                    unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
+                    unet_added_conditions = {"time_ids": add_time_ids}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=None,
@@ -1954,6 +1994,8 @@ def main(args):
             pipeline = StableDiffusionXLPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 vae=vae,
+                tokenizer=tokenizer_one,
+                tokenizer_2=tokenizer_two,
                 text_encoder=accelerator.unwrap_model(text_encoder_one),
                 text_encoder_2=accelerator.unwrap_model(text_encoder_two),
                 unet=accelerator.unwrap_model(unet),
@@ -2033,6 +2075,11 @@ def main(args):
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
         )
+        if args.train_text_encoder_ti:
+            embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
+            embedding_handler.save_embeddings(embeddings_path)
+
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
             # Final inference
@@ -2068,6 +2115,25 @@ def main(args):
             # load attention processors
             pipeline.load_lora_weights(args.output_dir)

+            # load new tokens
+            if args.train_text_encoder_ti:
+                state_dict = load_file(embeddings_path)
+                all_new_tokens = []
+                for key, value in token_abstraction_dict.items():
+                    all_new_tokens.extend(value)
+                pipeline.load_textual_inversion(
+                    state_dict["clip_l"],
+                    token=all_new_tokens,
+                    text_encoder=pipeline.text_encoder,
+                    tokenizer=pipeline.tokenizer,
+                )
+                pipeline.load_textual_inversion(
+                    state_dict["clip_g"],
+                    token=all_new_tokens,
+                    text_encoder=pipeline.text_encoder_2,
+                    tokenizer=pipeline.tokenizer_2,
+                )
             # run inference
             pipeline = pipeline.to(accelerator.device)
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
@@ -2090,11 +2156,6 @@ def main(args):
                 }
             )

-        if args.train_text_encoder_ti:
-            embedding_handler.save_embeddings(
-                f"{args.output_dir}/{args.output_dir}_emb.safetensors",
-            )
-
     # Conver to WebUI format
     lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
     peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
...