Unverified Commit 9a7f8246 authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[Flux] Add advanced training script + support textual inversion inference (#9434)

* add ostris trainer to README & add cache latents of vae

* add ostris trainer to README & add cache latents of vae

* style

* readme

* add test for latent caching

* add ostris noise scheduler
https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95

* style

* fix import

* style

* fix tests

* style

* --change upcasting of transformer?

* update readme according to main

* add pivotal tuning for CLIP

* fix imports, encode_prompt call,add TextualInversionLoaderMixin to FluxPipeline for inference

* TextualInversionLoaderMixin support for FluxPipeline for inference

* move changes to advanced flux script, revert canonical

* add latent caching to canonical script

* revert changes to canonical script to keep it separate from https://github.com/huggingface/diffusers/pull/9160

* revert changes to canonical script to keep it separate from https://github.com/huggingface/diffusers/pull/9160

* style

* remove redundant line and change code block placement to align with logic

* add initializer_token arg

* add transformer frac for range support from pure textual inversion to the orig pivotal tuning

* support pure textual inversion - wip

* adjustments to support pure textual inversion and transformer optimization in only part of the epochs

* fix logic when using initializer token

* fix pure_textual_inversion_condition

* fix ti/pivotal loading of last validation run

* remove embeddings loading for ti in final training run (to avoid adding huggingface hub dependency)

* support pivotal for t5

* adapt pivotal for T5 encoder

* adapt pivotal for T5 encoder and support in flux pipeline

* t5 pivotal support + support fo pivotal for clip only or both

* fix param chaining

* fix param chaining

* README first draft

* readme

* readme

* readme

* style

* fix import

* style

* add fix from https://github.com/huggingface/diffusers/pull/9419



* add to readme, change function names

* te lr changes

* readme

* change concept tokens logic

* fix indices

* change arg name

* style

* dummy test

* revert dummy test

* reorder pivoting

* add warning in case the token abstraction is not the instance prompt

* experimental - wip - specific block training

* fix documentation and token abstraction processing

* remove transformer block specification feature (for now)

* style

* fix copies

* fix indexing issue when --initializer_concept has different amounts

* add if TextualInversionLoaderMixin to all flux pipelines

* style

* fix import

* fix imports

* address review comments - remove necessary prints & comments, use pin_memory=True, use free_memory utils, unify warning and prints

* style

* logger info fix

* make lora target modules configurable and change the default

* make lora target modules configurable and change the default

* style

* make lora target modules configurable and change the default, add notes to readme

* style

* add tests

* style

* fix repo id

* add updated requirements for advanced flux

* fix indices of t5 pivotal tuning embeddings

* fix path in test

* remove `pin_memory`

* fix filename of embedding

* fix filename of embedding

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent d9029f2c
# Advanced diffusion training examples
## Train Dreambooth LoRA with Flux.1 Dev
> [!TIP]
> 💡 This example follows some of the techniques and recommended practices covered in the community derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script).
> As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models we suggest to take a look 🤗
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like flux, stable diffusion given just a few(3~5) images of a subject.
LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
The `train_dreambooth_lora_flux_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_flux.py`, with
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
[ostris](https://x.com/ostrisai):[ai-toolkit](https://github.com/ostris/ai-toolkit), [bghira](https://github.com/bghira):[SimpleTuner](https://github.com/bghira/SimpleTuner), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
> [!NOTE]
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/advanced_diffusion_training` folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell e.g. a notebook
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
the exact modules for LoRA training. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> [!NOTE]
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
### Pivotal Tuning (and more)
**Training with text encoder(s)**
Alongside the Transformer, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
available with `train_dreambooth_lora_flux_advanced.py`, in the advanced script **pivotal tuning** is also supported.
[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
We then optimize the newly-inserted token embeddings to represent the new concept.
To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
Please keep the following points in mind:
* Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2) , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only.
To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
* **pure textual inversion** - to support the full range from pivotal tuning to textual inversion we introduce `--train_transformer_frac` which controls the amount of epochs the transformer LoRA layers are trained. By default, `--train_transformer_frac==1`, to trigger a textual inversion run set `--train_transformer_frac==0`. Values between 0 and 1 are supported as well, and we welcome the community to experiment w/ different settings and share the results!
* **token initializer** - similar to the original textual inversion work, you can specify a token of your choosing as the starting point for training. By default, when enabling `--train_text_encoder_ti`, the new inserted tokens are initialized randomly. You can specify a token in `--initializer_token` such that the starting point for the trained embeddings will be the embeddings associated with your chosen `--initializer_token`.
## Training examples
Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./3d_icon"
snapshot_download(
"LinoyTsaban/3d_icon",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
Let's review some of the advanced features we're going to be using for this example:
- **custom captions**:
To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
```bash
pip install datasets
```
Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
```
--dataset_name=./3d_icon
--caption_column=prompt
```
You can also load a dataset straight from by specifying it's name in `dataset_name`.
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset.
- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
- **pivotal tuning**
### Example #1: Pivotal tuning
**Now, we can launch training:**
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-Flux-LoRA"
accelerate launch train_dreambooth_lora_flux_advanced.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir=$OUTPUT_DIR \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=1 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--train_text_encoder_ti_frac=0.5\
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
Our experiments were conducted on a single 40GB A100 GPU.
### Example #2: Pivotal tuning with T5
Now let's try that with T5 as well, so instead of only optimizing the CLIP embeddings associated with newly inserted tokens, we'll optimize
the T5 embeddings as well. We can do this by simply adding `--enable_t5_ti` to the previous configuration:
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-Flux-LoRA"
accelerate launch train_dreambooth_lora_flux_advanced.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir=$OUTPUT_DIR \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=1 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--enable_t5_ti\
--train_text_encoder_ti_frac=0.5\
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
### Example #3: Textual Inversion
To explore a pure textual inversion - i.e. only optimizing the text embeddings w/o training transformer LoRA layers, we
can set the value for `--train_transformer_frac` - which is responsible for the percent of epochs in which the transformer is
trained. By setting `--train_transformer_frac == 0` and enabling `--train_text_encoder_ti` we trigger a textual inversion train
run.
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-Flux-LoRA"
accelerate launch train_dreambooth_lora_flux_advanced.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir=$OUTPUT_DIR \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=1 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--enable_t5_ti\
--train_text_encoder_ti_frac=0.5\
--train_transformer_frac=0\
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
### Inference - pivotal tuning
Once training is done, we can perform inference like so:
1. starting with loading the transformer lora weights
```python
import torch
from huggingface_hub import hf_hub_download, upload_file
from diffusers import AutoPipelineForText2Image
from safetensors.torch import load_file
username = "linoyts"
repo_id = f"{username}/3d-icon-Flux-LoRA"
pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
```
2. now we load the pivotal tuning embeddings
💡note that if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
```python
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```
3. let's generate images
```python
instance_token = "<s0><s1>"
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image.save("llama.png")
```
### Inference - pure textual inversion
In this case, we don't load transformer layers as before, since we only optimize the text embeddings. The output of a textual inversion train run is a
`.safetensors` file containing the trained embeddings for the new tokens either for the CLIP encoder, or for both encoders (CLIP and T5)
1. starting with loading the embeddings.
💡note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
```python
import torch
from huggingface_hub import hf_hub_download, upload_file
from diffusers import AutoPipelineForText2Image
from safetensors.torch import load_file
username = "linoyts"
repo_id = f"{username}/3d-icon-Flux-LoRA"
pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```
2. let's generate images
```python
instance_token = "<s0><s1>"
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image.save("llama.png")
```
### Comfy UI / AUTOMATIC1111 Inference
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
**AUTOMATIC1111 / SD.Next** \
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
**ComfyUI** \
In ComfyUI we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
accelerate>=0.31.0
torchvision
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft>=0.11.1
sentencepiece
\ No newline at end of file
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py"
def test_dreambooth_lora_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_text_encoder_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--train_text_encoder
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
starts_with_expected_prefix = all(
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_expected_prefix)
def test_dreambooth_lora_pivotal_tuning_flux_clip(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--train_text_encoder_ti
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure embeddings were also saved
self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# make sure the state_dict has the correct naming in the parameters.
textual_inversion_state_dict = safetensors.torch.load_file(
os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
)
is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys())
self.assertTrue(is_clip)
# when performing pivotal tuning, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--train_text_encoder_ti
--enable_t5_ti
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure embeddings were also saved
self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# make sure the state_dict has the correct naming in the parameters.
textual_inversion_state_dict = safetensors.torch.load_file(
os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
)
is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys())
self.assertTrue(is_te)
# when performing pivotal tuning, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import copy
import itertools
import logging
import math
import os
import random
import re
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import save_file
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
import diffusers
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxPipeline,
FluxTransformer2DModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
def save_model_card(
repo_id: str,
images=None,
base_model: str = None,
train_text_encoder=False,
train_text_encoder_ti=False,
enable_t5_ti=False,
pure_textual_inversion=False,
token_abstraction_dict=None,
instance_prompt=None,
validation_prompt=None,
repo_folder=None,
):
widget_dict = []
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
widget_dict.append(
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
)
diffusers_load_lora = ""
diffusers_imports_pivotal = ""
diffusers_example_pivotal = ""
if not pure_textual_inversion:
diffusers_load_lora = (
f"""pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')"""
)
if train_text_encoder_ti:
embeddings_filename = f"{repo_folder}_emb"
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
trigger_str = (
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
"in you prompt with the new inserted tokens:\n"
)
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
if enable_t5_ti:
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.load_textual_inversion(state_dict["t5"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
"""
else:
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
"""
if token_abstraction_dict:
for key, value in token_abstraction_dict.items():
tokens = "".join(value)
trigger_str += f"""
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""
model_description = f"""
# Flux DreamBooth LoRA - {repo_id}
<Gallery />
## Model description
These are {repo_id} DreamBooth LoRA weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).
Was LoRA for the text encoder enabled? {train_text_encoder}.
Pivotal tuning was enabled: {train_text_encoder_ti}.
## Trigger words
{trigger_str}
## Download model
[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch
{diffusers_imports_pivotal}
pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
{diffusers_load_lora}
{diffusers_example_pivotal}
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## License
Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="other",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"diffusers-training",
"diffusers",
"lora",
"flux",
"flux-diffusers",
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def load_text_encoders(class_one, class_two):
text_encoder_one = class_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = class_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
return text_encoder_one, text_encoder_two
def log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
autocast_ctx = nullcontext()
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
phase_name: [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
]
}
)
del pipeline
free_memory()
return images
def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "T5EncoderModel":
from transformers import T5EncoderModel
return T5EncoderModel
else:
raise ValueError(f"{model_class} is not supported.")
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
help=("A folder containing the training data. "),
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument(
"--image_column",
type=str,
default="image",
help="The column of the dataset containing the target image. By "
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
parser.add_argument(
"--caption_column",
type=str,
default=None,
help="The column of the dataset containing the instance prompt for each image",
)
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
"--token_abstraction",
type=str,
default="TOK",
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. "
"'TOK,TOK2,TOK3' etc.",
)
parser.add_argument(
"--num_new_tokens_per_abstraction",
type=int,
default=None,
help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
"--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
"tokens - <si><si+1> ",
)
parser.add_argument(
"--initializer_concept",
type=str,
default=None,
help="the concept to use to initialize the new inserted tokens when training with "
"--train_text_encoder_ti = True. By default, new tokens (<si><si+1>) are initialized with random value. "
"Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. "
"--num_new_tokens_per_abstraction is ignored when initializer_concept is provided",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--max_sequence_length",
type=int,
default=512,
help="Maximum sequence length to use with with the T5 text encoder",
)
parser.add_argument(
"--validation_prompt",
type=str,
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
"--num_validation_images",
type=int,
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_epochs",
type=int,
default=50,
help=(
"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--rank",
type=int,
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
parser.add_argument(
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="flux-dreambooth-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_text_encoder_ti",
action="store_true",
help=("Whether to use pivotal tuning / textual inversion"),
)
parser.add_argument(
"--enable_t5_ti",
action="store_true",
help=(
"Whether to use pivotal tuning / textual inversion for the T5 encoder as well (in addition to CLIP encoder)"
),
)
parser.add_argument(
"--train_text_encoder_ti_frac",
type=float,
default=0.5,
help=("The percentage of epochs to perform textual inversion"),
)
parser.add_argument(
"--train_text_encoder_frac",
type=float,
default=1.0,
help=("The percentage of epochs to perform text encoder tuning"),
)
parser.add_argument(
"--train_transformer_frac",
type=float,
default=1.0,
help=("The percentage of epochs to perform transformer tuning"),
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=3.5,
help="the FLUX.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--text_encoder_lr",
type=float,
default=5e-6,
help="Text encoder learning rate to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--weighting_scheme",
type=str,
default="none",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--optimizer",
type=str,
default="AdamW",
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
)
parser.add_argument(
"--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
)
parser.add_argument(
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--prodigy_beta3",
type=float,
default=None,
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW",
)
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument(
"--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params"
)
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--lora_layers",
type=str,
default=None,
help=(
"The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. "
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
),
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
)
parser.add_argument(
"--prodigy_use_bias_correction",
type=bool,
default=True,
help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
)
parser.add_argument(
"--prodigy_safeguard_warmup",
type=bool,
default=True,
help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
"Ignored if optimizer is adamW",
)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
default=False,
help="Cache the VAE latents",
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
default=None,
choices=["no", "fp32", "fp16", "bf16"],
help=(
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
if args.dataset_name is None and args.instance_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
if args.dataset_name is not None and args.instance_data_dir is not None:
raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
if args.train_text_encoder and args.train_text_encoder_ti:
raise ValueError(
"Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. "
"For full LoRA text encoder training check --train_text_encoder, for textual "
"inversion training check `--train_text_encoder_ti`"
)
if args.train_transformer_frac < 1 and not args.train_text_encoder_ti:
raise ValueError(
"--train_transformer_frac must be == 1 if text_encoder training / textual inversion is not enabled."
)
if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1:
raise ValueError(
"--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. "
"This contradicts with --max_train_steps, please specify different values or set both to 1."
)
if args.enable_t5_ti and not args.train_text_encoder_ti:
logger.warning("You need not use --enable_t5_ti without --train_text_encoder_ti.")
if args.train_text_encoder_ti and args.initializer_concept and args.num_new_tokens_per_abstraction:
logger.warning(
"When specifying --initializer_concept, the number of tokens per abstraction is detrimned "
"by the initializer token. --num_new_tokens_per_abstraction will be ignored"
)
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
if args.with_prior_preservation:
if args.class_data_dir is None:
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
else:
if args.class_data_dir is not None:
logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
if args.class_prompt is not None:
logger.warning("You need not use --class_prompt without --with_prior_preservation.")
return args
# Modified from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
class TokenEmbeddingsHandler:
def __init__(self, text_encoders, tokenizers):
self.text_encoders = text_encoders
self.tokenizers = tokenizers
self.train_ids: Optional[torch.Tensor] = None
self.train_ids_t5: Optional[torch.Tensor] = None
self.inserting_toks: Optional[List[str]] = None
self.embeddings_settings = {}
def initialize_new_tokens(self, inserting_toks: List[str]):
idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all(
isinstance(tok, str) for tok in inserting_toks
), "All elements in inserting_toks should be strings."
self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
tokenizer.add_special_tokens(special_tokens_dict)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Convert the token abstractions to ids
if idx == 0:
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
else:
self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)
# random initialization of new tokens
embeds = (
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
)
std_token_embedding = embeds.weight.data.std()
logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
# if initializer_concept are not provided, token embeddings are initialized randomly
if args.initializer_concept is None:
hidden_size = (
text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
)
embeds.weight.data[train_ids] = (
torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
* std_token_embedding
)
else:
# Convert the initializer_token, placeholder_token to ids
initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False)
for token_idx, token_id in enumerate(train_ids):
embeds.weight.data[token_id] = (embeds.weight.data)[
initializer_token_ids[token_idx % len(initializer_token_ids)]
].clone()
self.embeddings_settings[f"original_embeddings_{idx}"] = embeds.weight.data.clone()
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
# makes sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[train_ids] = False
self.embeddings_settings[f"index_no_updates_{idx}"] = index_no_updates
logger.info(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
idx += 1
def save_embeddings(self, file_path: str):
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
tensors = {}
# text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
for idx, text_encoder in enumerate(self.text_encoders):
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
embeds = (
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
)
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
new_token_embeddings = embeds.weight.data[train_ids]
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0),
# Note: When loading with diffusers, any name can work - simply specify in inference
tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
# tensors[f"text_encoders_{idx}"] = new_token_embeddings
save_file(tensors, file_path)
@property
def dtype(self):
return self.text_encoders[0].dtype
@property
def device(self):
return self.text_encoders[0].device
@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
embeds = (
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
)
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
embeds.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
.to(device=text_encoder.device)
.to(dtype=text_encoder.dtype)
)
# for the parts that were updated, we need to normalize them
# to have the same std as before
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
index_updates = ~index_no_updates
new_embeddings = embeds.weight.data[index_updates]
off_ratio = std_token_embedding / new_embeddings.std()
new_embeddings = new_embeddings * (off_ratio**0.1)
embeds.weight.data[index_updates] = new_embeddings
class DreamBoothDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images.
"""
def __init__(
self,
instance_data_root,
instance_prompt,
class_prompt,
train_text_encoder_ti,
token_abstraction_dict=None, # token mapping for textual inversion
class_data_root=None,
class_num=None,
size=1024,
repeats=1,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
self.class_prompt = class_prompt
self.token_abstraction_dict = token_abstraction_dict
self.train_text_encoder_ti = train_text_encoder_ti
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if args.dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
raise ImportError(
"You are trying to load your data using the datasets library. If you wish to train using custom "
"captions please install the datasets library: `pip install datasets`. If you wish to load a "
"local folder containing images only, specify --instance_data_dir instead."
)
# Downloading and loading a dataset from the hub.
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
if args.caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
"column as --caption_column"
)
self.custom_instance_prompts = None
else:
if args.caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
custom_instance_prompts = dataset["train"][args.caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
else:
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
self.custom_instance_prompts = None
self.instance_images = []
for img in instance_images:
self.instance_images.extend(itertools.repeat(img, repeats))
self.pixel_values = []
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
for image in self.instance_images:
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
image = train_transforms(image)
self.pixel_values.append(image)
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
if class_num is not None:
self.num_class_images = min(len(self.class_images_path), class_num)
else:
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption:
if self.train_text_encoder_ti:
# replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in self.token_abstraction_dict.items():
caption = caption.replace(token_abs, "".join(token_replacement))
example["instance_prompt"] = caption
else:
example["instance_prompt"] = self.instance_prompt
else: # the given instance prompt is used for all images
example["instance_prompt"] = self.instance_prompt
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
class_image = exif_transpose(class_image)
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt"] = self.class_prompt
return example
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if with_prior_preservation:
pixel_values += [example["class_images"] for example in examples]
prompts += [example["class_prompt"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values, "prompts": prompts}
return batch
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
return text_input_ids
def _get_t5_prompt_embeds(
text_encoder,
tokenizer,
max_sequence_length=512,
prompt=None,
num_images_per_prompt=1,
device=None,
text_input_ids=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def _get_clip_prompt_embeds(
text_encoder,
tokenizer,
prompt: str,
device=None,
text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
def encode_prompt(
text_encoders,
tokenizers,
prompt: str,
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
dtype = text_encoders[0].dtype
pooled_prompt_embeds = _get_clip_prompt_embeds(
text_encoder=text_encoders[0],
tokenizer=tokenizers[0],
prompt=prompt,
device=device if device is not None else text_encoders[0].device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
)
prompt_embeds = _get_t5_prompt_embeds(
text_encoder=text_encoders[1],
tokenizer=tokenizers[1],
max_sequence_length=max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[1].device,
text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, pooled_prompt_embeds, text_ids
# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000
# generate the multiplier based on cosmap loss weighing
# this is only used on linear timesteps for now
# cosine map weighing is higher in the middle and lower at the ends
# bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
# cosmap_weighing = 2 / (math.pi * bot)
# sigma sqrt weighing is significantly higher at the end and lower at the beginning
sigma_sqrt_weighing = (self.sigmas**-2.0).float()
# clip at 1e4 (1e6 is too high)
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
# bring to a mean of 1
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
self.linear_timesteps = timesteps
# self.linear_timesteps_weights = cosmap_weighing
self.linear_timesteps_weights = sigma_sqrt_weighing
# self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
pass
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
# Get the weights for the timesteps
weights = self.linear_timesteps_weights[step_indices].flatten()
return weights
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# timestep needs to be in [0, 1], we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
# n_dim = original_samples.ndim
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, linear=False):
if linear:
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
else:
# distribute them closer to center. Inference distributes them as a bias toward first
# Generate values from 0 to 1
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
# Scale and reverse the values to go from 1000 to 0
timesteps = (1 - t) * 1000
# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)
self.timesteps = timesteps.to(device=device)
return timesteps
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
kwargs_handlers=[kwargs],
)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
torch_dtype = torch.float16
elif args.prior_generation_precision == "bf16":
torch_dtype = torch.bfloat16
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
free_memory()
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
model_id = args.hub_model_id or Path(args.output_dir).name
repo_id = None
if args.push_to_hub:
repo_id = create_repo(
repo_id=model_id,
exist_ok=True,
).repo_id
# Load the tokenizers
tokenizer_one = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
)
tokenizer_two = T5TokenizerFast.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
)
# import correct text encoder classes
text_encoder_cls_one = import_model_class_from_model_name_or_path(
args.pretrained_model_name_or_path, args.revision
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
)
# Load scheduler and models
noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
)
transformer = FluxTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
)
if args.train_text_encoder_ti:
# we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
# TOK2" -> ["TOK", "TOK2"] etc.
token_abstraction_list = [place_holder.strip() for place_holder in re.split(r",\s*", args.token_abstraction)]
logger.info(f"list of token identifiers: {token_abstraction_list}")
if args.initializer_concept is None:
num_new_tokens_per_abstraction = (
2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction
)
# if args.initializer_concept is provided, we ignore args.num_new_tokens_per_abstraction
else:
token_ids = tokenizer_one.encode(args.initializer_concept, add_special_tokens=False)
num_new_tokens_per_abstraction = len(token_ids)
if args.enable_t5_ti:
token_ids_t5 = tokenizer_two.encode(args.initializer_concept, add_special_tokens=False)
num_new_tokens_per_abstraction = max(len(token_ids), len(token_ids_t5))
logger.info(
f"initializer_concept: {args.initializer_concept}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}"
)
token_abstraction_dict = {}
token_idx = 0
for i, token in enumerate(token_abstraction_list):
token_abstraction_dict[token] = [f"<s{token_idx + i + j}>" for j in range(num_new_tokens_per_abstraction)]
token_idx += num_new_tokens_per_abstraction - 1
# replace instances of --token_abstraction in --instance_prompt with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in token_abstraction_dict.items():
new_instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
if args.instance_prompt == new_instance_prompt:
logger.warning(
"Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified "
"--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning"
)
args.instance_prompt = new_instance_prompt
if args.with_prior_preservation:
args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
if args.validation_prompt:
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
# initialize the new tokens for textual inversion
text_encoders = [text_encoder_one, text_encoder_two] if args.enable_t5_ti else [text_encoder_one]
tokenizers = [tokenizer_one, tokenizer_two] if args.enable_t5_ti else [tokenizer_one]
embedding_handler = TokenEmbeddingsHandler(text_encoders, tokenizers)
inserting_toks = []
for new_tok in token_abstraction_dict.values():
inserting_toks.extend(new_tok)
embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks)
# We only train the additional adapter LoRA layers
transformer.requires_grad_(False)
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
vae.to(accelerator.device, dtype=weight_dtype)
transformer.to(accelerator.device, dtype=weight_dtype)
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = [
"attn.to_k",
"attn.to_q",
"attn.to_v",
"attn.to_out.0",
"attn.add_k_proj",
"attn.add_q_proj",
"attn.add_v_proj",
"attn.to_add_out",
"ff.net.0.proj",
"ff.net.2",
"ff_context.net.0.proj",
"ff_context.net.2",
]
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder_one.add_adapter(text_lora_config)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
text_encoder_one_lora_layers_to_save = None
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(unwrap_model(text_encoder_one))):
if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
FluxPipeline.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
)
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
def load_model_hook(models, input_dir):
transformer_ = None
text_encoder_one_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
if args.train_text_encoder:
# Do we need to call `scale_lora_layers()` here?
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [transformer_]
if args.train_text_encoder:
models.extend([text_encoder_one_])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [transformer]
if args.train_text_encoder:
models.extend([text_encoder_one])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models, dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
# if we use textual inversion, we freeze all parameters except for the token embeddings
# in text encoder
elif args.train_text_encoder_ti:
text_lora_parameters_one = [] # CLIP
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
param.requires_grad = False
if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well
text_lora_parameters_two = []
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
param.requires_grad = False
# If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
# if --train_text_encoder_ti and train_transformer_frac == 0 where essentially performing textual inversion
# and not training transformer LoRA layers
pure_textual_inversion = args.train_text_encoder_ti and args.train_transformer_frac == 0
# Optimization parameters
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
if not freeze_text_encoder:
# different learning rate for text encoder and transformer
text_parameters_one_with_lr = {
"params": text_lora_parameters_one,
"weight_decay": args.adam_weight_decay_text_encoder
if args.adam_weight_decay_text_encoder
else args.adam_weight_decay,
"lr": args.text_encoder_lr,
}
if not args.enable_t5_ti:
# pure textual inversion - only clip
if pure_textual_inversion:
params_to_optimize = [
text_parameters_one_with_lr,
]
te_idx = 0
else: # regular te training or regular pivotal for clip
params_to_optimize = [
transformer_parameters_with_lr,
text_parameters_one_with_lr,
]
te_idx = 1
elif args.enable_t5_ti:
# pivotal tuning of clip & t5
text_parameters_two_with_lr = {
"params": text_lora_parameters_two,
"weight_decay": args.adam_weight_decay_text_encoder
if args.adam_weight_decay_text_encoder
else args.adam_weight_decay,
"lr": args.text_encoder_lr,
}
# pure textual inversion - only clip & t5
if pure_textual_inversion:
params_to_optimize = [text_parameters_one_with_lr, text_parameters_two_with_lr]
te_idx = 0
else: # regular pivotal tuning of clip & t5
params_to_optimize = [
transformer_parameters_with_lr,
text_parameters_one_with_lr,
text_parameters_two_with_lr,
]
te_idx = 1
else:
params_to_optimize = [
transformer_parameters_with_lr,
]
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
if args.optimizer.lower() == "adamw":
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
if args.optimizer.lower() == "prodigy":
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if not freeze_text_encoder and args.text_encoder_lr:
logger.warning(
f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters to be
# --learning_rate
params_to_optimize[te_idx]["lr"] = args.learning_rate
params_to_optimize[-1]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
train_text_encoder_ti=args.train_text_encoder_ti,
token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
class_prompt=args.class_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_num=args.num_class_images,
size=args.resolution,
repeats=args.repeats,
center_crop=args.center_crop,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=args.dataloader_num_workers,
)
if freeze_text_encoder:
tokenizers = [tokenizer_one, tokenizer_two]
text_encoders = [text_encoder_one, text_encoder_two]
def compute_text_embeddings(prompt, text_encoders, tokenizers):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders, tokenizers, prompt, args.max_sequence_length
)
prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
text_ids = text_ids.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds, text_ids
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if freeze_text_encoder and not train_dataset.custom_instance_prompts:
instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(
args.instance_prompt, text_encoders, tokenizers
)
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
if freeze_text_encoder:
class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings(
args.class_prompt, text_encoders, tokenizers
)
# Clear the memory here
if freeze_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders, text_encoder_one, text_encoder_two
free_memory()
# if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
add_special_tokens_clip = True if args.train_text_encoder_ti else False
add_special_tokens_t5 = True if (args.train_text_encoder_ti and args.enable_t5_ti) else False
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
if not train_dataset.custom_instance_prompts:
if freeze_text_encoder:
prompt_embeds = instance_prompt_hidden_states
pooled_prompt_embeds = instance_pooled_prompt_embeds
text_ids = instance_text_ids
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
text_ids = torch.cat([text_ids, class_text_ids], dim=0)
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts)
# we need to tokenize and encode the batch prompts on all training steps
else:
tokens_one = tokenize_prompt(
tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens_clip
)
tokens_two = tokenize_prompt(
tokenizer_two,
args.instance_prompt,
max_sequence_length=args.max_sequence_length,
add_special_tokens=add_special_tokens_t5,
)
if args.with_prior_preservation:
class_tokens_one = tokenize_prompt(
tokenizer_one,
args.class_prompt,
max_sequence_length=77,
add_special_tokens=add_special_tokens_clip,
)
class_tokens_two = tokenize_prompt(
tokenizer_two,
args.class_prompt,
max_sequence_length=args.max_sequence_length,
add_special_tokens=add_special_tokens_t5,
)
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_block_out_channels = vae.config.block_out_channels
if args.cache_latents:
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if args.validation_prompt is None:
del vae
free_memory()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
if not freeze_text_encoder:
if args.enable_t5_ti:
(
transformer,
text_encoder_one,
text_encoder_two,
optimizer,
train_dataloader,
lr_scheduler,
) = accelerator.prepare(
transformer,
text_encoder_one,
text_encoder_two,
optimizer,
train_dataloader,
lr_scheduler,
)
else:
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
)
else:
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_name = "dreambooth-flux-dev-lora-advanced"
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the mos recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
args.resume_from_checkpoint = None
initial_global_step = 0
else:
logger.info(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
timesteps = timesteps.to(accelerator.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
if args.train_text_encoder:
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs)
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs)
# flag used for textual inversion
pivoted_te = False
pivoted_tr = False
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
# if performing any kind of optimization of text_encoder params
if args.train_text_encoder or args.train_text_encoder_ti:
if epoch == num_train_epochs_text_encoder:
# flag to stop text encoder optimization
logger.info(f"PIVOT TE {epoch}")
pivoted_te = True
else:
# still optimizing the text encoder
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
text_encoder_one.train()
if args.enable_t5_ti:
text_encoder_two.train()
if epoch == num_train_epochs_transformer:
# flag to stop transformer optimization
logger.info(f"PIVOT TRANSFORMER {epoch}")
pivoted_tr = True
for step, batch in enumerate(train_dataloader):
if pivoted_te:
# stopping optimization of text_encoder params
optimizer.param_groups[te_idx]["lr"] = 0.0
optimizer.param_groups[-1]["lr"] = 0.0
elif pivoted_tr and not pure_textual_inversion:
logger.info(f"PIVOT TRANSFORMER {epoch}")
optimizer.param_groups[0]["lr"] = 0.0
with accelerator.accumulate(transformer):
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if freeze_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers
)
else:
tokens_one = tokenize_prompt(
tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens_clip
)
tokens_two = tokenize_prompt(
tokenizer_two,
prompts,
max_sequence_length=args.max_sequence_length,
add_special_tokens=add_special_tokens_t5,
)
if not freeze_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two],
max_sequence_length=args.max_sequence_length,
device=accelerator.device,
prompt=prompts,
)
# Convert images to latent space
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
model_input.shape[2],
model_input.shape[3],
accelerator.device,
weight_dtype,
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching.
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
packed_noisy_model_input = FluxPipeline._pack_latents(
noisy_model_input,
batch_size=model_input.shape[0],
num_channels_latents=model_input.shape[1],
height=model_input.shape[2],
width=model_input.shape[3],
)
# handle guidance
if transformer.config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
guidance = None
# Predict the noise residual
model_pred = transformer(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
model_pred = FluxPipeline._unpack_latents(
model_pred,
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
vae_scale_factor=vae_scale_factor,
)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# flow matching loss
target = noise - model_input
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute prior loss
prior_loss = torch.mean(
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
target_prior.shape[0], -1
),
1,
)
prior_loss = prior_loss.mean()
# Compute regular loss.
loss = torch.mean(
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
loss = loss.mean()
if args.with_prior_preservation:
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
accelerator.backward(loss)
if accelerator.sync_gradients:
if not freeze_text_encoder:
if args.train_text_encoder:
params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
elif pure_textual_inversion:
params_to_clip = itertools.chain(
text_encoder_one.parameters(), text_encoder_two.parameters()
)
else:
params_to_clip = itertools.chain(
transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
)
else:
params_to_clip = itertools.chain(transformer.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# every step, we reset the embeddings to the original embeddings.
if args.train_text_encoder_ti:
embedding_handler.retract_embeddings()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
if freeze_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
torch_dtype=weight_dtype,
)
if freeze_text_encoder:
del text_encoder_one, text_encoder_two
free_memory()
elif args.train_text_encoder:
del text_encoder_two
free_memory()
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if args.train_text_encoder:
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
else:
text_encoder_lora_layers = None
if not pure_textual_inversion:
FluxPipeline.save_lora_weights(
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
if args.train_text_encoder_ti:
embeddings_path = f"{args.output_dir}/{os.path.basename(args.output_dir)}_emb.safetensors"
embedding_handler.save_embeddings(embeddings_path)
# Final inference
# Load previous pipeline
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
if not pure_textual_inversion:
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
torch_dtype=weight_dtype,
is_final_validation=True,
)
save_model_card(
model_id if not args.push_to_hub else repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
train_text_encoder_ti=args.train_text_encoder_ti,
enable_t5_ti=args.enable_t5_ti,
pure_textual_inversion=pure_textual_inversion,
token_abstraction_dict=train_dataset.token_abstraction_dict,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
)
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler from ...schedulers import FlowMatchEulerDiscreteScheduler
...@@ -137,7 +137,12 @@ def retrieve_timesteps( ...@@ -137,7 +137,12 @@ def retrieve_timesteps(
return timesteps, num_inference_steps return timesteps, num_inference_steps
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): class FluxPipeline(
DiffusionPipeline,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
):
r""" r"""
The Flux pipeline for text-to-image generation. The Flux pipeline for text-to-image generation.
...@@ -212,6 +217,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): ...@@ -212,6 +217,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2( text_inputs = self.tokenizer_2(
prompt, prompt,
padding="max_length", padding="max_length",
...@@ -255,6 +263,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): ...@@ -255,6 +263,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
......
...@@ -25,7 +25,7 @@ from transformers import ( ...@@ -25,7 +25,7 @@ from transformers import (
) )
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel from ...models.transformers import FluxTransformer2DModel
...@@ -238,6 +238,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -238,6 +238,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer_2( text_inputs = self.tokenizer_2(
prompt, prompt,
padding="max_length", padding="max_length",
...@@ -281,6 +284,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -281,6 +284,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
......
...@@ -11,7 +11,7 @@ from transformers import ( ...@@ -11,7 +11,7 @@ from transformers import (
) )
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel from ...models.transformers import FluxTransformer2DModel
...@@ -251,6 +251,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -251,6 +251,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2( text_inputs = self.tokenizer_2(
prompt, prompt,
padding="max_length", padding="max_length",
...@@ -295,6 +298,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -295,6 +298,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
......
...@@ -12,7 +12,7 @@ from transformers import ( ...@@ -12,7 +12,7 @@ from transformers import (
) )
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel from ...models.transformers import FluxTransformer2DModel
...@@ -261,6 +261,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -261,6 +261,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2( text_inputs = self.tokenizer_2(
prompt, prompt,
padding="max_length", padding="max_length",
...@@ -305,6 +308,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -305,6 +308,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
......
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler from ...schedulers import FlowMatchEulerDiscreteScheduler
...@@ -235,6 +235,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -235,6 +235,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2( text_inputs = self.tokenizer_2(
prompt, prompt,
padding="max_length", padding="max_length",
...@@ -279,6 +282,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -279,6 +282,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
......
...@@ -21,7 +21,7 @@ import torch ...@@ -21,7 +21,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler from ...schedulers import FlowMatchEulerDiscreteScheduler
...@@ -239,6 +239,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -239,6 +239,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2( text_inputs = self.tokenizer_2(
prompt, prompt,
padding="max_length", padding="max_length",
...@@ -283,6 +286,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -283,6 +286,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
prompt, prompt,
padding="max_length", padding="max_length",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment