Unverified Commit 5ffb73d4 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

let's go Flux2 🚀 (#12711)



* add vae

* Initial commit for Flux 2 Transformer implementation

* add pipeline part

* small edits to the pipeline and conversion

* update conversion script

* fix

* up up

* finish pipeline

* Remove Flux IP Adapter logic for now

* Remove deprecated 3D id logic

* Remove ControlNet logic for now

* Add link to ViT-22B paper as reference for parallel transformer blocks such as the Flux 2 single stream block

* update pipeline

* Don't use biases for input projs and output AdaNorm

* up

* Remove bias for double stream block text QKV projections

* Add script to convert Flux 2 transformer to diffusers

* make style and make quality

* fix a few things.

* allow sft files to go.

* fix image processor

* fix batch

* style a bit

* Fix some bugs in Flux 2 transformer implementation

* Fix dummy input preparation and fix some test bugs

* fix dtype casting in timestep guidance module.

* resolve conflicts.,

* remove ip adapter stuff.

* Fix Flux 2 transformer consistency test

* Fix bug in Flux2TransformerBlock (double stream block)

* Get remaining Flux 2 transformer tests passing

* make style; make quality; make fix-copies

* remove stuff.

* fix type annotaton.

* remove unneeded stuff from tests

* tests

* up

* up

* add sf support

* Remove unused IP Adapter and ControlNet logic from transformer (#9)

* copied from

* Apply suggestions from code review
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarapolinário <joaopaulo.passos@gmail.com>

* up

* up

* up

* up

* up

* Refactor Flux2Attention into separate classes for double stream and single stream attention

* Add _supports_qkv_fusion to AttentionModuleMixin to allow subclasses to disable QKV fusion

* Have Flux2ParallelSelfAttention inherit from AttentionModuleMixin with _supports_qkv_fusion=False

* Log debug message when calling fuse_projections on a AttentionModuleMixin subclass that does not support QKV fusion

* Address review comments

* Update src/diffusers/pipelines/flux2/pipeline_flux2.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* up

* Remove maybe_allow_in_graph decorators for Flux 2 transformer blocks (#12)

* up

* support ostris loras. (#13)

* up

* update schdule

* up

* up (#17)

* add training scripts (#16)

* add training scripts
Co-authored-by: default avatarLinoy Tsaban <linoytsaban@gmail.com>

* model cpu offload in validation.

* add flux.2 readme

* add img2img and tests

* cpu offload in log validation

* Apply suggestions from code review

* fix

* up

* fixes

* remove i2i training tests for now.

---------
Co-authored-by: default avatarLinoy Tsaban <linoytsaban@gmail.com>
Co-authored-by: default avatarlinoytsaban <linoy@huggingface.co>

* up

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
Co-authored-by: default avatarDaniel Gu <dgu8957@gmail.com>
Co-authored-by: default avataryiyi@huggingface.co <yiyi@ip-10-53-87-203.ec2.internal>
Co-authored-by: default avatardg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: default avatarapolinário <joaopaulo.passos@gmail.com>
Co-authored-by: default avataryiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: default avatarLinoy Tsaban <linoytsaban@gmail.com>
Co-authored-by: default avatarlinoytsaban <linoy@huggingface.co>
parent 4088e8a8
...@@ -349,6 +349,8 @@ ...@@ -349,6 +349,8 @@
title: DiTTransformer2DModel title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d - local: api/models/easyanimate_transformer3d
title: EasyAnimateTransformer3DModel title: EasyAnimateTransformer3DModel
- local: api/models/flux2_transformer
title: Flux2Transformer2DModel
- local: api/models/flux_transformer - local: api/models/flux_transformer
title: FluxTransformer2DModel title: FluxTransformer2DModel
- local: api/models/hidream_image_transformer - local: api/models/hidream_image_transformer
...@@ -525,6 +527,8 @@ ...@@ -525,6 +527,8 @@
title: EasyAnimate title: EasyAnimate
- local: api/pipelines/flux - local: api/pipelines/flux
title: Flux title: Flux
- local: api/pipelines/flux2
title: Flux2
- local: api/pipelines/control_flux_inpaint - local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint title: FluxControlInpaint
- local: api/pipelines/hidream - local: api/pipelines/hidream
......
...@@ -30,7 +30,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi ...@@ -30,7 +30,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4). - [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`]. - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream) - [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen) - [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more. - [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
> [!TIP] > [!TIP]
...@@ -56,6 +57,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi ...@@ -56,6 +57,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
## Flux2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
## CogVideoXLoraLoaderMixin ## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
......
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Flux2Transformer2DModel
A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev).
## Flux2Transformer2DModel
[[autodoc]] Flux2Transformer2DModel
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Flux2
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch!
Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2-dev).
> [!TIP]
> Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
>
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## Flux2Pipeline
[[autodoc]] Flux2Pipeline
- all
- __call__
\ No newline at end of file
# DreamBooth training example for FLUX.2 [dev]
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2-dev).
> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux2-training)
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
hf auth login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_flux.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
## Memory Optimizations
> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption.
> However some techniques may be mutually exclusive so be sure to check before launching a training run.
### Remote Text Encoder
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
This way, the text encoder model is not loaded into memory during training.
> [!NOTE]
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
enable FP8 training by passing `--do_fp8_training`.
> [!IMPORTANT] Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater.
> If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers like SimpleTuner, ai-toolkit, etc.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing:
`--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
### Gradient Checkpointing and Accumulation
* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass.
by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs.
* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
Make sure to install `bitsandbytes` if you want to do so.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2"
accelerate launch train_dreambooth_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--remote_text_encoder \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
to use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify -
```bash
--optimizer="prodigy"
```
> [!TIP]
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
To perform DreamBooth with LoRA, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2-lora"
accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--remote_text_encoder \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="wandb" \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
- lora_alpha vs. rank:
This ratio dictates the LoRA's effective strength:
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
the exact modules for LoRA training. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> [!NOTE]
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
## Training Image-to-Image
Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
**important**
**Important**
To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
To start, you must have a dataset containing triplets:
* Condition image - the input image to be transformed.
* Target image - the desired output image after transformation.
* Instruction - a text prompt describing the transformation from the condition image to the target image.
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
```bash
accelerate launch train_dreambooth_lora_flux2_img2img.py \
--pretrained_model_name_or_path=black-forest-labs/FLUX.2-dev \
--output_dir="flux2-i2i" \
--dataset_name="kontext-community/relighting" \
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
--do_fp8_training \
--gradient_checkpointing \
--remote_text_encoder \
--cache_latents \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--optimizer="adamw" \
--use_8bit_adam \
--cache_latents \
--learning_rate=1e-4 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=200 \
--max_train_steps=1000 \
--rank=16\
--seed="0"
```
More generally, when performing I2I fine-tuning, we expect you to:
* Have a dataset `kontext-community/relighting`
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
### Misc notes
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
### Aspect Ratio Bucketing
we've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
`
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAFlux2(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "dog"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2"
script_path = "examples/dreambooth/train_dreambooth_lora_flux2.py"
transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
def test_dreambooth_lora_flux2(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
starts_with_transformer = all(
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--max_sequence_length 8
--checkpointing_steps=2
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
# "torch>=2.0.0",
# "accelerate>=0.31.0",
# "transformers>=4.41.2",
# "ftfy",
# "tensorboard",
# "Jinja2",
# "peft>=0.11.1",
# "sentencepiece",
# "torchvision",
# "datasets",
# "bitsandbytes",
# "prodigyopt",
# ]
# ///
import argparse
import copy
import itertools
import json
import logging
import math
import os
import random
import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path
import numpy as np
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
import diffusers
from diffusers import (
AutoencoderKLFlux2,
BitsAndBytesConfig,
FlowMatchEulerDiscreteScheduler,
Flux2Pipeline,
Flux2Transformer2DModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
offload_models,
parse_buckets_string,
)
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
def save_model_card(
repo_id: str,
images=None,
base_model: str = None,
instance_prompt=None,
validation_prompt=None,
repo_folder=None,
quant_training=None,
):
widget_dict = []
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
widget_dict.append(
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
)
model_description = f"""
# Flux2 DreamBooth LoRA - {repo_id}
<Gallery />
## Model description
These are {repo_id} DreamBooth LoRA weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).
Quant training? {quant_training}
## Trigger words
You should use `{instance_prompt}` to trigger the image generation.
## Download model
[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## License
Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="other",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"diffusers-training",
"diffusers",
"lora",
"flux2",
"flux2-diffusers",
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(dtype=torch_dtype)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
phase_name: [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
]
}
)
del pipeline
free_memory()
return images
def module_filter_fn(mod: torch.nn.Module, fqn: str):
# don't convert the output module
if fqn == "proj_out":
return False
# don't convert linear modules with weight dimensions not divisible by 16
if isinstance(mod, torch.nn.Linear):
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
return True
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--bnb_quantization_config_path",
type=str,
default=None,
help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
)
parser.add_argument(
"--do_fp8_training",
action="store_true",
help="if we are doing FP8 training.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
help=("A folder containing the training data. "),
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument(
"--image_column",
type=str,
default="image",
help="The column of the dataset containing the target image. By "
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
parser.add_argument(
"--caption_column",
type=str,
default=None,
help="The column of the dataset containing the instance prompt for each image",
)
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--max_sequence_length",
type=int,
default=512,
help="Maximum sequence length to use with with the T5 text encoder",
)
parser.add_argument(
"--text_encoder_out_layers",
type=int,
nargs="+",
default=[10, 20, 30],
help="Text encoder hidden layers to compute the final text embeddings.",
)
parser.add_argument(
"--validation_prompt",
type=str,
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
"--skip_final_inference",
default=False,
action="store_true",
help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.",
)
parser.add_argument(
"--final_validation_prompt",
type=str,
default=None,
help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.",
)
parser.add_argument(
"--num_validation_images",
type=int,
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_epochs",
type=int,
default=50,
help=(
"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--rank",
type=int,
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
parser.add_argument(
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="flux-dreambooth-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--aspect_ratio_buckets",
type=str,
default=None,
help=(
"Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
"e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'"
"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=3.5,
help="the FLUX.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--text_encoder_lr",
type=float,
default=5e-6,
help="Text encoder learning rate to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--weighting_scheme",
type=str,
default="none",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--optimizer",
type=str,
default="AdamW",
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
)
parser.add_argument(
"--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
)
parser.add_argument(
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--prodigy_beta3",
type=float,
default=None,
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW",
)
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--lora_layers",
type=str,
default=None,
help=(
'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
),
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
)
parser.add_argument(
"--prodigy_use_bias_correction",
type=bool,
default=True,
help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
)
parser.add_argument(
"--prodigy_safeguard_warmup",
type=bool,
default=True,
help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
"Ignored if optimizer is adamW",
)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
default=False,
help="Cache the VAE latents",
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--upcast_before_saving",
action="store_true",
default=False,
help=(
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
"Defaults to precision dtype used for training to save memory"
),
)
parser.add_argument(
"--offload",
action="store_true",
help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
)
parser.add_argument(
"--remote_text_encoder",
action="store_true",
help="Whether to use a remote text encoder. This means the text encoder will not be loaded locally and instead, the prompt embeddings will be computed remotely using the HuggingFace Inference API.",
)
parser.add_argument(
"--prior_generation_precision",
type=str,
default=None,
choices=["no", "fp32", "fp16", "bf16"],
help=(
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
if args.dataset_name is None and args.instance_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
if args.dataset_name is not None and args.instance_data_dir is not None:
raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
if args.do_fp8_training and args.bnb_quantization_config_path:
raise ValueError("Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.")
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
if args.with_prior_preservation:
if args.class_data_dir is None:
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
else:
# logger is not available yet
if args.class_data_dir is not None:
warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
return args
class DreamBoothDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images.
"""
def __init__(
self,
instance_data_root,
instance_prompt,
class_prompt,
class_data_root=None,
class_num=None,
size=1024,
repeats=1,
center_crop=False,
buckets=None,
):
self.size = size
self.center_crop = center_crop
self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
self.class_prompt = class_prompt
self.buckets = buckets
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if args.dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
raise ImportError(
"You are trying to load your data using the datasets library. If you wish to train using custom "
"captions please install the datasets library: `pip install datasets`. If you wish to load a "
"local folder containing images only, specify --instance_data_dir instead."
)
# Downloading and loading a dataset from the hub.
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
if args.caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
"column as --caption_column"
)
self.custom_instance_prompts = None
else:
if args.caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
custom_instance_prompts = dataset["train"][args.caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
else:
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
self.custom_instance_prompts = None
self.instance_images = []
for img in instance_images:
self.instance_images.extend(itertools.repeat(img, repeats))
self.pixel_values = []
for i, image in enumerate(self.instance_images):
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
width, height = image.size
# Find the closest bucket
bucket_idx = find_nearest_bucket(height, width, self.buckets)
target_height, target_width = self.buckets[bucket_idx]
self.size = (target_height, target_width)
# based on the bucket assignment, define the transformations
image = self.train_transform(
image,
size=self.size,
center_crop=args.center_crop,
random_flip=args.random_flip,
)
self.pixel_values.append((image, bucket_idx))
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
if class_num is not None:
self.num_class_images = min(len(self.class_images_path), class_num)
else:
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption:
example["instance_prompt"] = caption
else:
example["instance_prompt"] = self.instance_prompt
else: # custom prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
class_image = exif_transpose(class_image)
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt"] = self.class_prompt
return example
def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False):
# 1. Resize (deterministic)
resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
image = resize(image)
# 2. Crop: either center or SAME random crop
if center_crop:
crop = transforms.CenterCrop(size)
image = crop(image)
else:
# get_params returns (i, j, h, w)
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
image = TF.crop(image, i, j, h, w)
# 3. Random horizontal flip with the SAME coin flip
if random_flip:
do_flip = random.random() < 0.5
if do_flip:
image = TF.hflip(image)
# 4. ToTensor + Normalize (deterministic)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
image = normalize(to_tensor(image))
return image
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if with_prior_preservation:
pixel_values += [example["class_images"] for example in examples]
prompts += [example["class_prompt"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values, "prompts": prompts}
return batch
class BucketBatchSampler(BatchSampler):
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
# Group indices by bucket
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
self.bucket_indices[bucket_idx].append(idx)
self.sampler_len = 0
self.batches = []
# Pre-generate batches for each bucket
for indices_in_bucket in self.bucket_indices:
# Shuffle indices within the bucket
random.shuffle(indices_in_bucket)
# Create batches
for i in range(0, len(indices_in_bucket), self.batch_size):
batch = indices_in_bucket[i : i + self.batch_size]
if len(batch) < self.batch_size and self.drop_last:
continue # Skip partial batch if drop_last is True
self.batches.append(batch)
self.sampler_len += 1 # Count the number of batches
def __iter__(self):
# Shuffle the order of the batches each epoch
random.shuffle(self.batches)
for batch in self.batches:
yield batch
def __len__(self):
return self.sampler_len
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
if args.do_fp8_training:
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
kwargs_handlers=[kwargs],
)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
torch_dtype = torch.float16
elif args.prior_generation_precision == "bf16":
torch_dtype = torch.bfloat16
pipeline = Flux2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
images = pipeline(prompt=example["prompt"]).images
for i, image in enumerate(images):
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
free_memory()
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
).repo_id
# Load the tokenizers
tokenizer = PixtralProcessor.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="scheduler",
revision=args.revision,
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
vae = AutoencoderKLFlux2.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
accelerator.device
)
quantization_config = None
if args.bnb_quantization_config_path is not None:
with open(args.bnb_quantization_config_path, "r") as f:
config_kwargs = json.load(f)
if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
quantization_config = BitsAndBytesConfig(**config_kwargs)
transformer = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
revision=args.revision,
variant=args.variant,
quantization_config=quantization_config,
torch_dtype=weight_dtype,
)
if args.bnb_quantization_config_path is not None:
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
if not args.remote_text_encoder:
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder.requires_grad_(False)
# We only train the additional adapter LoRA layers
transformer.requires_grad_(False)
vae.requires_grad_(False)
if args.enable_npu_flash_attention:
if is_torch_npu_available():
logger.info("npu flash attention enabled.")
transformer.set_attention_backend("_native_npu")
else:
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype}
# flux vae is stable in bf16 so load it in weight_dtype to reduce memory
vae.to(**to_kwargs)
# we never offload the transformer to CPU, so we can just use the accelerator device
transformer_to_kwargs = (
{"device": accelerator.device}
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
)
if not args.remote_text_encoder:
text_encoder.to(**to_kwargs)
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = Flux2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=None,
transformer=None,
tokenizer=tokenizer,
text_encoder=text_encoder,
scheduler=None,
revision=args.revision,
)
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
Flux2Pipeline.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)
def load_model_hook(models, input_dir):
transformer_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [transformer_]
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [transformer]
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models, dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
# Optimization parameters
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
params_to_optimize = [transformer_parameters_with_lr]
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
if args.optimizer.lower() == "adamw":
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
if args.optimizer.lower() == "prodigy":
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
if args.aspect_ratio_buckets is not None:
buckets = parse_buckets_string(args.aspect_ratio_buckets)
else:
buckets = [(args.resolution, args.resolution)]
logger.info(f"Using parsed aspect ratio buckets: {buckets}")
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_num=args.num_class_images,
size=args.resolution,
repeats=args.repeats,
center_crop=args.center_crop,
buckets=buckets,
)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=args.dataloader_num_workers,
)
def compute_text_embeddings(prompt, text_encoding_pipeline):
with torch.no_grad():
prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
prompt=prompt,
max_sequence_length=args.max_sequence_length,
text_encoder_out_layers=args.text_encoder_out_layers,
)
return prompt_embeds, text_ids
def compute_remote_text_embeddings(prompts):
import io
import requests
if args.hub_token is not None:
hf_token = args.hub_token
else:
from huggingface_hub import get_token
hf_token = get_token()
if hf_token is None:
raise ValueError(
"No HuggingFace token found. To use the remote text encoder please login using `hf auth login` or provide a token using --hub_token"
)
def _encode_single(prompt: str):
response = requests.post(
"https://remote-text-encoder-flux-2.huggingface.co/predict",
json={"prompt": prompt},
headers={"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"},
)
assert response.status_code == 200, f"{response.status_code=}"
return torch.load(io.BytesIO(response.content))
try:
if isinstance(prompts, (list, tuple)):
embeds = [_encode_single(p) for p in prompts]
prompt_embeds = torch.cat(embeds, dim=0)
else:
prompt_embeds = _encode_single(prompts)
text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(accelerator.device)
prompt_embeds = prompt_embeds.to(accelerator.device)
return prompt_embeds, text_ids
except Exception as e:
raise RuntimeError("Remote text encoder inference failed.") from e
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
instance_prompt_hidden_states, instance_text_ids = compute_remote_text_embeddings(args.instance_prompt)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(
args.instance_prompt, text_encoding_pipeline
)
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
if args.remote_text_encoder:
class_prompt_hidden_states, class_text_ids = compute_remote_text_embeddings(args.class_prompt)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
class_prompt_hidden_states, class_text_ids = compute_text_embeddings(
args.class_prompt, text_encoding_pipeline
)
validation_embeddings = {}
if args.validation_prompt is not None:
if args.remote_text_encoder:
(validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = (
compute_remote_text_embeddings(args.validation_prompt)
)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
(validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = compute_text_embeddings(
args.validation_prompt, text_encoding_pipeline
)
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
if not train_dataset.custom_instance_prompts:
prompt_embeds = instance_prompt_hidden_states
text_ids = instance_text_ids
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
text_ids = torch.cat([text_ids, class_text_ids], dim=0)
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
if precompute_latents:
prompt_embeds_cache = []
text_ids_cache = []
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
if args.cache_latents:
with offload_models(vae, device=accelerator.device, offload=args.offload):
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
prompt_embeds_cache.append(prompt_embeds)
text_ids_cache.append(text_ids)
# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
if args.cache_latents:
vae = vae.to("cpu")
del vae
# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
if not args.remote_text_encoder:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
del text_encoder, tokenizer
free_memory()
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_name = "dreambooth-flux2-lora"
args_cp = vars(args).copy()
args_cp["text_encoder_out_layers"] = str(args_cp["text_encoder_out_layers"])
accelerator.init_trackers(tracker_name, config=args_cp)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the mos recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
timesteps = timesteps.to(accelerator.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
prompts = batch["prompts"]
with accelerator.accumulate(models_to_accumulate):
if train_dataset.custom_instance_prompts:
prompt_embeds = prompt_embeds_cache[step]
text_ids = text_ids_cache[step]
else:
num_repeat_elements = len(prompts)
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
# Convert images to latent space
if args.cache_latents:
model_input = latents_cache[step].mode()
else:
with offload_models(vae, device=accelerator.device, offload=args.offload):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.mode()
model_input = Flux2Pipeline._patchify_latents(model_input)
model_input = (model_input - latents_bn_mean) / latents_bn_std
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching.
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# [B, C, H, W] -> [B, H*W, C]
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
# handle guidance
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
# Predict the noise residual
model_pred = transformer(
hidden_states=packed_noisy_model_input, # (B, image_seq_len, C)
timestep=timesteps / 1000,
guidance=guidance,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=model_input_ids, # B, image_seq_len, 4
return_dict=False,
)[0]
model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# flow matching loss
target = noise - model_input
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute prior loss
prior_loss = torch.mean(
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
target_prior.shape[0], -1
),
1,
)
prior_loss = prior_loss.mean()
# Compute regular loss.
loss = torch.mean(
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
loss = loss.mean()
if args.with_prior_preservation:
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = transformer.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
pipeline = Flux2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=None,
tokenizer=None,
transformer=unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=validation_embeddings,
epoch=epoch,
torch_dtype=weight_dtype,
)
del pipeline
free_memory()
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
**_collate_lora_metadata(modules_to_save),
)
images = []
run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)
should_run_final_inference = not args.skip_final_inference and run_validation
if should_run_final_inference:
pipeline = Flux2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=validation_embeddings,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
images = None
del pipeline
free_memory()
validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
quant_training = None
if args.do_fp8_training:
quant_training = "FP8 TorchAO"
elif args.bnb_quantization_config_path:
quant_training = "BitsandBytes"
save_model_card(
(args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
instance_prompt=args.instance_prompt,
validation_prompt=validation_prompt,
repo_folder=args.output_dir,
quant_training=quant_training,
)
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
# "torch>=2.0.0",
# "accelerate>=0.31.0",
# "transformers>=4.41.2",
# "ftfy",
# "tensorboard",
# "Jinja2",
# "peft>=0.11.1",
# "sentencepiece",
# "torchvision",
# "datasets",
# "bitsandbytes",
# "prodigyopt",
# ]
# ///
import argparse
import copy
import itertools
import json
import logging
import math
import os
import random
import shutil
from contextlib import nullcontext
from pathlib import Path
import numpy as np
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
import diffusers
from diffusers import (
AutoencoderKLFlux2,
BitsAndBytesConfig,
FlowMatchEulerDiscreteScheduler,
Flux2Pipeline,
Flux2Transformer2DModel,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
from diffusers.training_utils import (
_collate_lora_metadata,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
offload_models,
parse_buckets_string,
)
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
is_wandb_available,
load_image,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
def save_model_card(
repo_id: str,
images=None,
base_model: str = None,
instance_prompt=None,
validation_prompt=None,
repo_folder=None,
fp8_training=False,
):
widget_dict = []
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
widget_dict.append(
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
)
model_description = f"""
# Flux DreamBooth LoRA - {repo_id}
<Gallery />
## Model description
These are {repo_id} DreamBooth LoRA weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).
FP8 training? {fp8_training}
## Trigger words
You should use `{instance_prompt}` to trigger the image generation.
## Download model
[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## License
Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="other",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"diffusers-training",
"diffusers",
"lora",
"flux2",
"flux2-diffusers",
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(dtype=torch_dtype)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
image=pipeline_args["image"],
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
phase_name: [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
]
}
)
del pipeline
free_memory()
return images
def module_filter_fn(mod: torch.nn.Module, fqn: str):
# don't convert the output module
if fqn == "proj_out":
return False
# don't convert linear modules with weight dimensions not divisible by 16
if isinstance(mod, torch.nn.Linear):
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
return True
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--bnb_quantization_config_path",
type=str,
default=None,
help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
)
parser.add_argument(
"--do_fp8_training",
action="store_true",
help="if we are doing FP8 training.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
help=("A folder containing the training data. "),
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument(
"--image_column",
type=str,
default="image",
help="The column of the dataset containing the target image. By "
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
parser.add_argument(
"--cond_image_column",
type=str,
default=None,
help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
)
parser.add_argument(
"--caption_column",
type=str,
default=None,
help="The column of the dataset containing the instance prompt for each image",
)
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
"--max_sequence_length",
type=int,
default=512,
help="Maximum sequence length to use with with the T5 text encoder",
)
parser.add_argument(
"--validation_prompt",
type=str,
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
"--validation_image",
type=str,
default=None,
help="path to an image that is used during validation as the condition image to verify that the model is learning.",
)
parser.add_argument(
"--skip_final_inference",
default=False,
action="store_true",
help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.",
)
parser.add_argument(
"--final_validation_prompt",
type=str,
default=None,
help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.",
)
parser.add_argument(
"--num_validation_images",
type=int,
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_epochs",
type=int,
default=50,
help=(
"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--rank",
type=int,
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--output_dir",
type=str,
default="flux-dreambooth-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--aspect_ratio_buckets",
type=str,
default=None,
help=(
"Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
"e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'"
"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=3.5,
help="the FLUX.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--weighting_scheme",
type=str,
default="none",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--optimizer",
type=str,
default="AdamW",
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
)
parser.add_argument(
"--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
)
parser.add_argument(
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--prodigy_beta3",
type=float,
default=None,
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW",
)
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--lora_layers",
type=str,
default=None,
help=(
'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
),
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
)
parser.add_argument(
"--prodigy_use_bias_correction",
type=bool,
default=True,
help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
)
parser.add_argument(
"--prodigy_safeguard_warmup",
type=bool,
default=True,
help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
"Ignored if optimizer is adamW",
)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
default=False,
help="Cache the VAE latents",
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--upcast_before_saving",
action="store_true",
default=False,
help=(
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
"Defaults to precision dtype used for training to save memory"
),
)
parser.add_argument(
"--offload",
action="store_true",
help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
)
parser.add_argument(
"--remote_text_encoder",
action="store_true",
help="Whether to use a remote text encoder. This means the text encoder will not be loaded locally and instead, the prompt embeddings will be computed remotely using the HuggingFace Inference API.",
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
if args.cond_image_column is None:
raise ValueError(
"you must provide --cond_image_column for image-to-image training. Otherwise please see Flux2 text-to-image training example."
)
else:
assert args.image_column is not None
assert args.caption_column is not None
if args.dataset_name is None and args.instance_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
if args.dataset_name is not None and args.instance_data_dir is not None:
raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
class DreamBoothDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images.
"""
def __init__(
self,
instance_data_root,
instance_prompt,
size=1024,
repeats=1,
center_crop=False,
buckets=None,
):
self.size = size
self.center_crop = center_crop
self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
self.buckets = buckets
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if args.dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
raise ImportError(
"You are trying to load your data using the datasets library. If you wish to train using custom "
"captions please install the datasets library: `pip install datasets`. If you wish to load a "
"local folder containing images only, specify --instance_data_dir instead."
)
# Downloading and loading a dataset from the hub.
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.cond_image_column is not None and args.cond_image_column not in column_names:
raise ValueError(
f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
cond_images = None
cond_image_column = args.cond_image_column
if cond_image_column is not None:
cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
assert len(instance_images) == len(cond_images)
if args.caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
"column as --caption_column"
)
self.custom_instance_prompts = None
else:
if args.caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
custom_instance_prompts = dataset["train"][args.caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
else:
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
self.custom_instance_prompts = None
self.instance_images = []
self.cond_images = []
for i, img in enumerate(instance_images):
self.instance_images.extend(itertools.repeat(img, repeats))
if args.dataset_name is not None and cond_images is not None:
self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
self.pixel_values = []
self.cond_pixel_values = []
for i, image in enumerate(self.instance_images):
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
dest_image = None
if self.cond_images: # todo: take care of max area for buckets
dest_image = self.cond_images[i]
image_width, image_height = dest_image.size
if image_width * image_height > 1024 * 1024:
dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
image_width, image_height = dest_image.size
multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
dest_image = Flux2ImageProcessor.image_processor.preprocess(
dest_image, height=image_height, width=image_width, resize_mode="crop"
)
dest_image = exif_transpose(dest_image)
if not dest_image.mode == "RGB":
dest_image = dest_image.convert("RGB")
width, height = image.size
# Find the closest bucket
bucket_idx = find_nearest_bucket(height, width, self.buckets)
target_height, target_width = self.buckets[bucket_idx]
self.size = (target_height, target_width)
# based on the bucket assignment, define the transformations
image, dest_image = self.paired_transform(
image,
dest_image=dest_image,
size=self.size,
center_crop=args.center_crop,
random_flip=args.random_flip,
)
self.pixel_values.append((image, bucket_idx))
if dest_image is not None:
self.cond_pixel_values.append((dest_image, bucket_idx))
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx
if self.cond_pixel_values:
dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
example["cond_images"] = dest_image
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption:
example["instance_prompt"] = caption
else:
example["instance_prompt"] = self.instance_prompt
else: # custom prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt
return example
def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
# 1. Resize (deterministic)
resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
image = resize(image)
if dest_image is not None:
dest_image = resize(dest_image)
# 2. Crop: either center or SAME random crop
if center_crop:
crop = transforms.CenterCrop(size)
image = crop(image)
if dest_image is not None:
dest_image = crop(dest_image)
else:
# get_params returns (i, j, h, w)
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
image = TF.crop(image, i, j, h, w)
if dest_image is not None:
dest_image = TF.crop(dest_image, i, j, h, w)
# 3. Random horizontal flip with the SAME coin flip
if random_flip:
do_flip = random.random() < 0.5
if do_flip:
image = TF.hflip(image)
if dest_image is not None:
dest_image = TF.hflip(dest_image)
# 4. ToTensor + Normalize (deterministic)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
image = normalize(to_tensor(image))
if dest_image is not None:
dest_image = normalize(to_tensor(dest_image))
return (image, dest_image) if dest_image is not None else (image, None)
def collate_fn(examples):
pixel_values = [example["instance_images"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values, "prompts": prompts}
if any("cond_images" in example for example in examples):
cond_pixel_values = [example["cond_images"] for example in examples]
cond_pixel_values = torch.stack(cond_pixel_values)
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
batch.update({"cond_pixel_values": cond_pixel_values})
return batch
class BucketBatchSampler(BatchSampler):
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
# Group indices by bucket
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
self.bucket_indices[bucket_idx].append(idx)
self.sampler_len = 0
self.batches = []
# Pre-generate batches for each bucket
for indices_in_bucket in self.bucket_indices:
# Shuffle indices within the bucket
random.shuffle(indices_in_bucket)
# Create batches
for i in range(0, len(indices_in_bucket), self.batch_size):
batch = indices_in_bucket[i : i + self.batch_size]
if len(batch) < self.batch_size and self.drop_last:
continue # Skip partial batch if drop_last is True
self.batches.append(batch)
self.sampler_len += 1 # Count the number of batches
def __iter__(self):
# Shuffle the order of the batches each epoch
random.shuffle(self.batches)
for batch in self.batches:
yield batch
def __len__(self):
return self.sampler_len
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
if args.do_fp8_training:
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
kwargs_handlers=[kwargs],
)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
).repo_id
# Load the tokenizers
tokenizer = PixtralProcessor.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="scheduler",
revision=args.revision,
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
vae = AutoencoderKLFlux2.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
accelerator.device
)
quantization_config = None
if args.bnb_quantization_config_path is not None:
with open(args.bnb_quantization_config_path, "r") as f:
config_kwargs = json.load(f)
if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
quantization_config = BitsAndBytesConfig(**config_kwargs)
transformer = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
revision=args.revision,
variant=args.variant,
quantization_config=quantization_config,
torch_dtype=weight_dtype,
)
if args.bnb_quantization_config_path is not None:
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
if not args.remote_text_encoder:
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder.requires_grad_(False)
# We only train the additional adapter LoRA layers
transformer.requires_grad_(False)
vae.requires_grad_(False)
if args.enable_npu_flash_attention:
if is_torch_npu_available():
logger.info("npu flash attention enabled.")
transformer.set_attention_backend("_native_npu")
else:
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype}
# flux vae is stable in bf16 so load it in weight_dtype to reduce memory
vae.to(**to_kwargs)
# we never offload the transformer to CPU, so we can just use the accelerator device
transformer_to_kwargs = (
{"device": accelerator.device}
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
)
if not args.remote_text_encoder:
text_encoder.to(**to_kwargs)
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = Flux2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=None,
transformer=None,
tokenizer=tokenizer,
text_encoder=text_encoder,
scheduler=None,
revision=args.revision,
)
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
Flux2Pipeline.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)
def load_model_hook(models, input_dir):
transformer_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [transformer_]
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [transformer]
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models, dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
# Optimization parameters
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
params_to_optimize = [transformer_parameters_with_lr]
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
if args.optimizer.lower() == "adamw":
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
if args.optimizer.lower() == "prodigy":
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
if args.aspect_ratio_buckets is not None:
buckets = parse_buckets_string(args.aspect_ratio_buckets)
else:
buckets = [(args.resolution, args.resolution)]
logger.info(f"Using parsed aspect ratio buckets: {buckets}")
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
size=args.resolution,
repeats=args.repeats,
center_crop=args.center_crop,
buckets=buckets,
)
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=lambda examples: collate_fn(examples),
num_workers=args.dataloader_num_workers,
)
def compute_text_embeddings(prompt, text_encoding_pipeline):
with torch.no_grad():
prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
prompt=prompt, max_sequence_length=args.max_sequence_length
)
# prompt_embeds = prompt_embeds.to(accelerator.device)
# text_ids = text_ids.to(accelerator.device)
return prompt_embeds, text_ids
def compute_remote_text_embeddings(prompts: str | list[str]):
import io
import requests
if args.hub_token is not None:
hf_token = args.hub_token
else:
from huggingface_hub import get_token
hf_token = get_token()
if hf_token is None:
raise ValueError(
"No HuggingFace token found. To use the remote text encoder please login using `hf auth login` or provide a token using --hub_token"
)
def _encode_single(prompt: str):
response = requests.post(
"https://remote-text-encoder-flux-2.huggingface.co/predict",
json={"prompt": prompt},
headers={"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"},
)
assert response.status_code == 200, f"{response.status_code=}"
return torch.load(io.BytesIO(response.content))
try:
if isinstance(prompts, (list, tuple)):
embeds = [_encode_single(p) for p in prompts]
prompt_embeds = torch.cat(embeds, dim=0).to(accelerator.device)
else:
prompt_embeds = _encode_single(prompts).to(accelerator.device)
text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(accelerator.device)
return prompt_embeds, text_ids
except Exception as e:
raise RuntimeError("Remote text encoder inference failed.") from e
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
instance_prompt_hidden_states, instance_text_ids = compute_remote_text_embeddings(args.instance_prompt)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(
args.instance_prompt, text_encoding_pipeline
)
validation_image = load_image(args.validation_image_path).convert("RGB")
validation_kwargs = {"image": validation_image}
if args.validation_prompt is not None:
if args.remote_text_encoder:
validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
validation_kwargs["prompt_embeds"] = compute_text_embeddings(
args.validation_prompt, text_encoding_pipeline
)
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
if not train_dataset.custom_instance_prompts:
prompt_embeds = instance_prompt_hidden_states
text_ids = instance_text_ids
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
if precompute_latents:
prompt_embeds_cache = []
text_ids_cache = []
latents_cache = []
cond_latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
if args.cache_latents:
with offload_models(vae, device=accelerator.device, offload=args.offload):
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
prompt_embeds_cache.append(prompt_embeds)
text_ids_cache.append(text_ids)
# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
if args.cache_latents:
vae = vae.to("cpu")
del vae
# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
if not args.remote_text_encoder:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
del text_encoder, tokenizer
free_memory()
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_name = "dreambooth-flux2-image2img-lora"
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the mos recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
timesteps = timesteps.to(accelerator.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
prompts = batch["prompts"]
with accelerator.accumulate(models_to_accumulate):
if train_dataset.custom_instance_prompts:
prompt_embeds = prompt_embeds_cache[step]
text_ids = text_ids_cache[step]
else:
num_repeat_elements = len(prompts)
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
# Convert images to latent space
if args.cache_latents:
model_input = latents_cache[step].mode()
cond_model_input = cond_latents_cache[step].mode()
else:
with offload_models(vae, device=accelerator.device, offload=args.offload):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.mode()
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
# model_input = Flux2Pipeline._encode_vae_image(pixel_values)
model_input = Flux2Pipeline._patchify_latents(model_input)
model_input = (model_input - latents_bn_mean) / latents_bn_std
cond_model_input = Flux2Pipeline._patchify_latents(cond_model_input)
cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
device=cond_model_input.device
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching.
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# [B, C, H, W] -> [B, H*W, C]
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
# concatenate the model inputs with the cond inputs
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
# handle guidance
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
# Predict the noise residual
model_pred = transformer(
hidden_states=packed_noisy_model_input, # (B, image_seq_len, C)
timestep=timesteps / 1000,
guidance=guidance,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=model_input_ids, # B, image_seq_len, 4
return_dict=False,
)[0]
model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# flow matching loss
target = noise - model_input
# Compute regular loss.
loss = torch.mean(
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
loss = loss.mean()
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = transformer.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
pipeline = Flux2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=None,
tokenizer=None,
transformer=unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=validation_kwargs,
epoch=epoch,
torch_dtype=weight_dtype,
)
del pipeline
free_memory()
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
**_collate_lora_metadata(modules_to_save),
)
images = []
run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)
should_run_final_inference = not args.skip_final_inference and run_validation
if should_run_final_inference:
pipeline = Flux2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=validation_kwargs,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
del pipeline
free_memory()
validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
save_model_card(
(args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
instance_prompt=args.instance_prompt,
validation_prompt=validation_prompt,
repo_folder=args.output_dir,
fp8_training=args.do_fp8_training,
)
if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)
import argparse
from contextlib import nullcontext
from typing import Any, Dict, Tuple
import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, GenerationConfig, Mistral3ForConditionalGeneration
from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
from diffusers.utils.import_utils import is_accelerate_available
"""
# VAE
python scripts/convert_flux2_to_diffusers.py \
--original_state_dict_repo_id "diffusers-internal-dev/new-model-image" \
--vae_filename "flux2-vae.sft" \
--output_path "/raid/yiyi/dummy-flux2-diffusers" \
--vae
# DiT
python scripts/convert_flux2_to_diffusers.py \
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
--dit_filename flux-dev-dummy.sft \
--dit \
--output_path .
# Full pipe
python scripts/convert_flux2_to_diffusers.py \
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
--dit_filename flux-dev-dummy.sft \
--vae_filename "flux2-vae.sft" \
--dit --vae --full_pipe \
--output_path .
"""
CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--dit", action="store_true")
parser.add_argument("--vae_dtype", type=str, default="fp32")
parser.add_argument("--dit_dtype", type=str, default="bf16")
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--full_pipe", action="store_true")
parser.add_argument("--output_path", type=str)
args = parser.parse_args()
def load_original_checkpoint(args, filename):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
DIFFUSERS_VAE_TO_FLUX2_MAPPING = {
"encoder.conv_in.weight": "encoder.conv_in.weight",
"encoder.conv_in.bias": "encoder.conv_in.bias",
"encoder.conv_out.weight": "encoder.conv_out.weight",
"encoder.conv_out.bias": "encoder.conv_out.bias",
"encoder.conv_norm_out.weight": "encoder.norm_out.weight",
"encoder.conv_norm_out.bias": "encoder.norm_out.bias",
"decoder.conv_in.weight": "decoder.conv_in.weight",
"decoder.conv_in.bias": "decoder.conv_in.bias",
"decoder.conv_out.weight": "decoder.conv_out.weight",
"decoder.conv_out.bias": "decoder.conv_out.bias",
"decoder.conv_norm_out.weight": "decoder.norm_out.weight",
"decoder.conv_norm_out.bias": "decoder.norm_out.bias",
"quant_conv.weight": "encoder.quant_conv.weight",
"quant_conv.bias": "encoder.quant_conv.bias",
"post_quant_conv.weight": "decoder.post_quant_conv.weight",
"post_quant_conv.bias": "decoder.post_quant_conv.bias",
"bn.running_mean": "bn.running_mean",
"bn.running_var": "bn.running_var",
}
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = (
ldm_key.replace(mapping["old"], mapping["new"])
.replace("norm.weight", "group_norm.weight")
.replace("norm.bias", "group_norm.bias")
.replace("q.weight", "to_q.weight")
.replace("q.bias", "to_q.bias")
.replace("k.weight", "to_k.weight")
.replace("k.bias", "to_k.bias")
.replace("v.weight", "to_v.weight")
.replace("v.bias", "to_v.bias")
.replace("proj_out.weight", "to_out.0.weight")
.replace("proj_out.bias", "to_out.0.bias")
)
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
# proj_attn.weight has to be converted from conv 1D to linear
shape = new_checkpoint[diffusers_key].shape
if len(shape) == 3:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
elif len(shape) == 4:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config):
new_checkpoint = {}
for diffusers_key, ldm_key in DIFFUSERS_VAE_TO_FLUX2_MAPPING.items():
if ldm_key not in vae_state_dict:
continue
new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len(config["down_block_types"])
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
)
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.bias"
)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
update_vae_attentions_ldm_to_diffusers(
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
)
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len(config["up_block_types"])
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
)
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
update_vae_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
vae_state_dict,
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
update_vae_attentions_ldm_to_diffusers(
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
# Image and text input projections
"img_in": "x_embedder",
"txt_in": "context_embedder",
# Timestep and guidance embeddings
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
# Modulation parameters
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
# Final output layer
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
"final_layer.linear": "proj_out",
}
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
"final_layer.adaLN_modulation.1": "norm_out.linear",
}
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
# Handle fused QKV projections separately as we need to break into Q, K, V projections
"img_attn.norm.query_norm": "attn.norm_q",
"img_attn.norm.key_norm": "attn.norm_k",
"img_attn.proj": "attn.to_out.0",
"img_mlp.0": "ff.linear_in",
"img_mlp.2": "ff.linear_out",
"txt_attn.norm.query_norm": "attn.norm_added_q",
"txt_attn.norm.key_norm": "attn.norm_added_k",
"txt_attn.proj": "attn.to_add_out",
"txt_mlp.0": "ff_context.linear_in",
"txt_mlp.2": "ff_context.linear_out",
}
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
"linear1": "attn.to_qkv_mlp_proj",
"linear2": "attn.to_out",
"norm.query_norm": "attn.norm_q",
"norm.key_norm": "attn.norm_k",
}
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use
# diffusers implementation
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight
if ".weight" not in key:
return
# If adaLN_modulation is in the key, swap scale and shift parameters
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
if "adaLN_modulation" in key:
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
# Assume all such keys are in the AdaLayerNorm key map
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
new_key = ".".join([new_key_without_param_type, param_type])
swapped_weight = swap_scale_shift(state_dict.pop(key))
state_dict[new_key] = swapped_weight
return
def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
new_prefix = "transformer_blocks"
if "double_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
if "qkv" in within_block_name:
fused_qkv_weight = state_dict.pop(key)
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
if "img" in modality_block_name:
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.to_q"
new_k_name = "attn.to_k"
new_v_name = "attn.to_v"
elif "txt" in modality_block_name:
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.add_q_proj"
new_k_name = "attn.add_k_proj"
new_v_name = "attn.add_v_proj"
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
state_dict[new_q_key] = to_q_weight
state_dict[new_k_key] = to_k_weight
state_dict[new_v_key] = to_v_weight
else:
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
# Mapping:
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
new_prefix = "single_transformer_blocks"
if "single_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"adaLN_modulation": convert_ada_layer_norm_weights,
"double_blocks": convert_flux2_double_stream_blocks,
"single_blocks": convert_flux2_single_stream_blocks,
}
def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
if model_type == "test" or model_type == "dummy-flux2":
config = {
"model_id": "diffusers-internal-dev/dummy-flux2",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
"num_layers": 8,
"num_single_layers": 48,
"attention_head_dim": 128,
"num_attention_heads": 48,
"joint_attention_dim": 15360,
"timestep_guidance_channels": 256,
"mlp_ratio": 3.0,
"axes_dims_rope": (32, 32, 32, 32),
"rope_theta": 2000,
"eps": 1e-6,
},
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str):
config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
transformer = Flux2Transformer2DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def main(args):
if args.vae:
original_vae_ckpt = load_original_checkpoint(args, filename=args.vae_filename)
vae = AutoencoderKLFlux2()
converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_vae_ckpt, vae.config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if not args.full_pipe:
vae_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
vae.to(vae_dtype).save_pretrained(f"{args.output_path}/vae")
if args.dit:
original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
if not args.full_pipe:
dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
if args.full_pipe:
tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
generate_config = GenerationConfig.from_pretrained(text_encoder_id)
generate_config.do_sample = True
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16
)
tokenizer = AutoProcessor.from_pretrained(tokenizer_id)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)
pipe = Flux2Pipeline(
vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)
pipe.save_pretrained(args.output_path)
if __name__ == "__main__":
main(args)
...@@ -186,6 +186,7 @@ else: ...@@ -186,6 +186,7 @@ else:
"AutoencoderKLAllegro", "AutoencoderKLAllegro",
"AutoencoderKLCogVideoX", "AutoencoderKLCogVideoX",
"AutoencoderKLCosmos", "AutoencoderKLCosmos",
"AutoencoderKLFlux2",
"AutoencoderKLHunyuanImage", "AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo",
...@@ -215,6 +216,7 @@ else: ...@@ -215,6 +216,7 @@ else:
"CosmosTransformer3DModel", "CosmosTransformer3DModel",
"DiTTransformer2DModel", "DiTTransformer2DModel",
"EasyAnimateTransformer3DModel", "EasyAnimateTransformer3DModel",
"Flux2Transformer2DModel",
"FluxControlNetModel", "FluxControlNetModel",
"FluxMultiControlNetModel", "FluxMultiControlNetModel",
"FluxTransformer2DModel", "FluxTransformer2DModel",
...@@ -458,6 +460,7 @@ else: ...@@ -458,6 +460,7 @@ else:
"EasyAnimateControlPipeline", "EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline", "EasyAnimateInpaintPipeline",
"EasyAnimatePipeline", "EasyAnimatePipeline",
"Flux2Pipeline",
"FluxControlImg2ImgPipeline", "FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline", "FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline", "FluxControlNetImg2ImgPipeline",
...@@ -902,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -902,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro, AutoencoderKLAllegro,
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
AutoencoderKLCosmos, AutoencoderKLCosmos,
AutoencoderKLFlux2,
AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo,
...@@ -931,6 +935,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -931,6 +935,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CosmosTransformer3DModel, CosmosTransformer3DModel,
DiTTransformer2DModel, DiTTransformer2DModel,
EasyAnimateTransformer3DModel, EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxControlNetModel, FluxControlNetModel,
FluxMultiControlNetModel, FluxMultiControlNetModel,
FluxTransformer2DModel, FluxTransformer2DModel,
...@@ -1143,6 +1148,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -1143,6 +1148,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateControlPipeline, EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline, EasyAnimateInpaintPipeline,
EasyAnimatePipeline, EasyAnimatePipeline,
Flux2Pipeline,
FluxControlImg2ImgPipeline, FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline, FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline, FluxControlNetImg2ImgPipeline,
......
...@@ -81,6 +81,7 @@ if is_torch_available(): ...@@ -81,6 +81,7 @@ if is_torch_available():
"HiDreamImageLoraLoaderMixin", "HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin", "SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin", "QwenImageLoraLoaderMixin",
"Flux2LoraLoaderMixin",
] ]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [ _import_structure["ip_adapter"] = [
...@@ -113,6 +114,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -113,6 +114,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AuraFlowLoraLoaderMixin, AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin, CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin, CogView4LoraLoaderMixin,
Flux2LoraLoaderMixin,
FluxLoraLoaderMixin, FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin, HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin, HunyuanVideoLoraLoaderMixin,
......
...@@ -2265,3 +2265,89 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict): ...@@ -2265,3 +2265,89 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()} converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict return converted_state_dict
def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
converted_state_dict = {}
prefix = "diffusion_model."
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
num_double_layers = 8
num_single_layers = 48
lora_keys = ("lora_A", "lora_B")
attn_types = ("img_attn", "txt_attn")
for sl in range(num_single_layers):
single_block_prefix = f"single_blocks.{sl}"
attn_prefix = f"single_transformer_blocks.{sl}.attn"
for lora_key in lora_keys:
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear1.{lora_key}.weight"
)
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear2.{lora_key}.weight"
)
for dl in range(num_double_layers):
transformer_block_prefix = f"transformer_blocks.{dl}"
for lora_key in lora_keys:
for attn_type in attn_types:
attn_prefix = f"{transformer_block_prefix}.attn"
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
fused_qkv_weight = original_state_dict.pop(qkv_key)
if lora_key == "lora_A":
diff_attn_proj_keys = (
["to_q", "to_k", "to_v"]
if attn_type == "img_attn"
else ["add_q_proj", "add_k_proj", "add_v_proj"]
)
for proj_key in diff_attn_proj_keys:
converted_state_dict[f"{attn_prefix}.{proj_key}.{lora_key}.weight"] = torch.cat(
[fused_qkv_weight]
)
else:
sample_q, sample_k, sample_v = torch.chunk(fused_qkv_weight, 3, dim=0)
if attn_type == "img_attn":
converted_state_dict[f"{attn_prefix}.to_q.{lora_key}.weight"] = torch.cat([sample_q])
converted_state_dict[f"{attn_prefix}.to_k.{lora_key}.weight"] = torch.cat([sample_k])
converted_state_dict[f"{attn_prefix}.to_v.{lora_key}.weight"] = torch.cat([sample_v])
else:
converted_state_dict[f"{attn_prefix}.add_q_proj.{lora_key}.weight"] = torch.cat([sample_q])
converted_state_dict[f"{attn_prefix}.add_k_proj.{lora_key}.weight"] = torch.cat([sample_k])
converted_state_dict[f"{attn_prefix}.add_v_proj.{lora_key}.weight"] = torch.cat([sample_v])
proj_mappings = [
("img_attn.proj", "attn.to_out.0"),
("txt_attn.proj", "attn.to_add_out"),
]
for org_proj, diff_proj in proj_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
mlp_mappings = [
("img_mlp.0", "ff.linear_in"),
("img_mlp.2", "ff.linear_out"),
("txt_mlp.0", "ff_context.linear_in"),
("txt_mlp.2", "ff_context.linear_out"),
]
for org_mlp, diff_mlp in mlp_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
...@@ -45,6 +45,7 @@ from .lora_conversion_utils import ( ...@@ -45,6 +45,7 @@ from .lora_conversion_utils import (
_convert_hunyuan_video_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers, _convert_non_diffusers_ltxv_lora_to_diffusers,
...@@ -5084,6 +5085,209 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5084,6 +5085,209 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs) super().unfuse_lora(components=components, **kwargs)
class Flux2LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
if is_ai_toolkit:
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
......
...@@ -62,6 +62,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = { ...@@ -62,6 +62,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"WanVACETransformer3DModel": lambda model_cls, weights: weights, "WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights, "ChromaTransformer2DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights, "QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
} }
......
...@@ -34,6 +34,7 @@ from .single_file_utils import ( ...@@ -34,6 +34,7 @@ from .single_file_utils import (
convert_chroma_transformer_checkpoint_to_diffusers, convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint, convert_controlnet_checkpoint,
convert_cosmos_transformer_checkpoint_to_diffusers, convert_cosmos_transformer_checkpoint_to_diffusers,
convert_flux2_transformer_checkpoint_to_diffusers,
convert_flux_transformer_checkpoint_to_diffusers, convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers, convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers, convert_hunyuan_video_transformer_to_diffusers,
...@@ -162,6 +163,10 @@ SINGLE_FILE_LOADABLE_CLASSES = { ...@@ -162,6 +163,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": lambda x: x, "checkpoint_mapping_fn": lambda x: x,
"default_subfolder": "transformer", "default_subfolder": "transformer",
}, },
"Flux2Transformer2DModel": {
"checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
} }
......
...@@ -140,6 +140,7 @@ CHECKPOINT_KEY_NAMES = { ...@@ -140,6 +140,7 @@ CHECKPOINT_KEY_NAMES = {
"net.blocks.0.self_attn.q_proj.weight", "net.blocks.0.self_attn.q_proj.weight",
"net.pos_embedder.dim_spatial_range", "net.pos_embedder.dim_spatial_range",
], ],
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
} }
DIFFUSERS_DEFAULT_PIPELINE_PATHS = { DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
...@@ -189,6 +190,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = { ...@@ -189,6 +190,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"}, "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"}, "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"}, "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"}, "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
...@@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint): ...@@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint):
else: else:
model_type = "animatediff_v3" model_type = "animatediff_v3"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]):
model_type = "flux-2-dev"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
if any( if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
...@@ -3647,3 +3652,168 @@ def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): ...@@ -3647,3 +3652,168 @@ def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, converted_state_dict) handler_fn_inplace(key, converted_state_dict)
return converted_state_dict return converted_state_dict
def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
# Image and text input projections
"img_in": "x_embedder",
"txt_in": "context_embedder",
# Timestep and guidance embeddings
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
# Modulation parameters
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
# Final output layer
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
"final_layer.linear": "proj_out",
}
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
"final_layer.adaLN_modulation.1": "norm_out.linear",
}
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
# Handle fused QKV projections separately as we need to break into Q, K, V projections
"img_attn.norm.query_norm": "attn.norm_q",
"img_attn.norm.key_norm": "attn.norm_k",
"img_attn.proj": "attn.to_out.0",
"img_mlp.0": "ff.linear_in",
"img_mlp.2": "ff.linear_out",
"txt_attn.norm.query_norm": "attn.norm_added_q",
"txt_attn.norm.key_norm": "attn.norm_added_k",
"txt_attn.proj": "attn.to_add_out",
"txt_mlp.0": "ff_context.linear_in",
"txt_mlp.2": "ff_context.linear_out",
}
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
"linear1": "attn.to_qkv_mlp_proj",
"linear2": "attn.to_out",
"norm.query_norm": "attn.norm_q",
"norm.key_norm": "attn.norm_k",
}
def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
# Mapping:
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
new_prefix = "single_transformer_blocks"
if "single_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None:
# Skip if not a weight
if ".weight" not in key:
return
# If adaLN_modulation is in the key, swap scale and shift parameters
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
if "adaLN_modulation" in key:
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
# Assume all such keys are in the AdaLayerNorm key map
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
new_key = ".".join([new_key_without_param_type, param_type])
swapped_weight = swap_scale_shift(state_dict.pop(key), 0)
state_dict[new_key] = swapped_weight
return
def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
# Skip if not a weight, bias, or scale
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
return
new_prefix = "transformer_blocks"
if "double_blocks." in key:
parts = key.split(".")
block_idx = parts[1]
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
within_block_name = ".".join(parts[2:-1])
param_type = parts[-1]
if param_type == "scale":
param_type = "weight"
if "qkv" in within_block_name:
fused_qkv_weight = state_dict.pop(key)
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
if "img" in modality_block_name:
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.to_q"
new_k_name = "attn.to_k"
new_v_name = "attn.to_v"
elif "txt" in modality_block_name:
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = "attn.add_q_proj"
new_k_name = "attn.add_k_proj"
new_v_name = "attn.add_v_proj"
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
state_dict[new_q_key] = to_q_weight
state_dict[new_k_key] = to_k_weight
state_dict[new_v_key] = to_v_weight
else:
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
param = state_dict.pop(key)
state_dict[new_key] = param
return
def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"adaLN_modulation": convert_ada_layer_norm_weights,
"double_blocks": convert_flux2_double_stream_blocks,
"single_blocks": convert_flux2_single_stream_blocks,
}
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle official code --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict(converted_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
...@@ -35,6 +35,7 @@ if is_torch_available(): ...@@ -35,6 +35,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"] _import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
...@@ -92,6 +93,7 @@ if is_torch_available(): ...@@ -92,6 +93,7 @@ if is_torch_available():
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"] _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
...@@ -141,6 +143,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -141,6 +143,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro, AutoencoderKLAllegro,
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
AutoencoderKLCosmos, AutoencoderKLCosmos,
AutoencoderKLFlux2,
AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo,
...@@ -191,6 +194,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -191,6 +194,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
DiTTransformer2DModel, DiTTransformer2DModel,
DualTransformer2DModel, DualTransformer2DModel,
EasyAnimateTransformer3DModel, EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel, FluxTransformer2DModel,
HiDreamImageTransformer2DModel, HiDreamImageTransformer2DModel,
HunyuanDiT2DModel, HunyuanDiT2DModel,
......
...@@ -105,7 +105,7 @@ class AttentionMixin: ...@@ -105,7 +105,7 @@ class AttentionMixin:
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
for module in self.modules(): for module in self.modules():
if isinstance(module, AttentionModuleMixin): if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
module.fuse_projections() module.fuse_projections()
def unfuse_qkv_projections(self): def unfuse_qkv_projections(self):
...@@ -114,13 +114,14 @@ class AttentionMixin: ...@@ -114,13 +114,14 @@ class AttentionMixin:
> [!WARNING] > This API is 🧪 experimental. > [!WARNING] > This API is 🧪 experimental.
""" """
for module in self.modules(): for module in self.modules():
if isinstance(module, AttentionModuleMixin): if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
module.unfuse_projections() module.unfuse_projections()
class AttentionModuleMixin: class AttentionModuleMixin:
_default_processor_cls = None _default_processor_cls = None
_available_processors = [] _available_processors = []
_supports_qkv_fusion = True
fused_projections = False fused_projections = False
def set_processor(self, processor: AttentionProcessor) -> None: def set_processor(self, processor: AttentionProcessor) -> None:
...@@ -248,6 +249,14 @@ class AttentionModuleMixin: ...@@ -248,6 +249,14 @@ class AttentionModuleMixin:
""" """
Fuse the query, key, and value projections into a single projection for efficiency. Fuse the query, key, and value projections into a single projection for efficiency.
""" """
# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
# single stream blocks are always fused)
if not self._supports_qkv_fusion:
logger.debug(
f"{self.__class__.__name__} does not support fusing QKV projections, so `fuse_projections` will no-op."
)
return
# Skip if already fused # Skip if already fused
if getattr(self, "fused_projections", False): if getattr(self, "fused_projections", False):
return return
...@@ -307,6 +316,11 @@ class AttentionModuleMixin: ...@@ -307,6 +316,11 @@ class AttentionModuleMixin:
""" """
Unfuse the query, key, and value projections back to separate projections. Unfuse the query, key, and value projections back to separate projections.
""" """
# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
# single stream blocks are always fused)
if not self._supports_qkv_fusion:
return
# Skip if not fused # Skip if not fused
if not getattr(self, "fused_projections", False): if not getattr(self, "fused_projections", False):
return return
......
...@@ -4,6 +4,7 @@ from .autoencoder_kl import AutoencoderKL ...@@ -4,6 +4,7 @@ from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_flux2 import AutoencoderKLFlux2
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
......
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
)
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKLFlux2(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
mid_block_add_attention (`bool`, *optional*, default to `True`):
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
mid_block will only have resnet blocks
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = (
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
),
up_block_types: Tuple[str, ...] = (
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
),
block_out_channels: Tuple[int, ...] = (
128,
256,
512,
512,
),
layers_per_block: int = 2,
act_fn: str = "silu",
latent_channels: int = 32,
norm_num_groups: int = 32,
sample_size: int = 1024, # YiYi notes: not sure
force_upcast: bool = True,
use_quant_conv: bool = True,
use_post_quant_conv: bool = True,
mid_block_add_attention: bool = True,
batch_norm_eps: float = 1e-4,
batch_norm_momentum: float = 0.1,
patch_size: Tuple[int, int] = (2, 2),
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
mid_block_add_attention=mid_block_add_attention,
)
# pass init params to Decoder
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
mid_block_add_attention=mid_block_add_attention,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
self.bn = nn.BatchNorm2d(
math.prod(patch_size) * latent_channels,
eps=batch_norm_eps,
momentum=batch_norm_momentum,
affine=False,
track_running_stats=True,
)
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self._tiled_encode(x)
enc = self.encoder(x)
if self.quant_conv is not None:
enc = self.quant_conv(enc)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
if self.config.use_quant_conv:
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
enc = torch.cat(result_rows, dim=2)
return enc
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
deprecation_message = (
"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
"to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
)
deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
if self.config.use_quant_conv:
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
if self.config.use_post_quant_conv:
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
> [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
> [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment