# Kandinsky 2.2 text-to-image fine-tuning

Kandinsky 2.2 includes a prior pipeline that generates image embeddings from text prompts, and a decoder pipeline that generates the output image based on the image embeddings. We provide `train_text_to_image_prior.py` and `train_text_to_image_decoder.py` scripts to show you how to fine-tune the Kandinsky prior and decoder models separately based on your own dataset. To achieve the best results, you should fine-tune **_both_** your prior and decoder models.
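
At inference time, the two stages are chained: the prior maps the text prompt to CLIP image embeddings, and the decoder turns those embeddings into pixels. As a quick orientation, here is a minimal sketch using the public `kandinsky-community` checkpoints (the generation arguments are illustrative):

```python
from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
import torch

# Stage 1: the prior maps the text prompt to CLIP image embeddings.
prior = KandinskyV22PriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
).to("cuda")
image_embeds, negative_image_embeds = prior("A robot naruto, 4k photo").to_tuple()

# Stage 2: the decoder generates the image from those embeddings.
decoder = KandinskyV22Pipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
).to("cuda")
image = decoder(
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
    height=768,
    width=768,
).images[0]
image.save("robot-naruto.png")
```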

___Note___:

___This script is experimental. It fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best results on your dataset.___


## Running locally with PyTorch

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then cd into the example folder and run
```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
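
If you're working in an environment without an interactive shell (for example a notebook), one option is to write a default 🤗 Accelerate configuration from Python instead:

```python
from accelerate.utils import write_basic_config

# Writes a default accelerate config file for the current machine.
write_basic_config()
```
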
For this example we want to directly store the trained weights on the Hub, so we need to be logged in and add the `--push_to_hub` flag.

___

### Naruto example

For all our examples, we will directly store the trained weights on the Hub, so we need to be logged in and add the `--push_to_hub` flag. In order to do that, you have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to the [User Access Tokens](https://huggingface.co/docs/hub/security-tokens) guide.

Run the following command to authenticate your token

```bash
huggingface-cli login
```
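
If you're working in a notebook, you can log in from Python instead (using `huggingface_hub`, which is installed as a dependency of `diffusers`):

```python
from huggingface_hub import notebook_login

# Opens a prompt for your Hugging Face access token.
notebook_login()
```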

We also use [Weights and Biases](https://docs.wandb.ai/quickstart) logging by default, because it is really useful to monitor the training progress by regularly generating sample images during training. To install wandb, run

```bash
pip install wandb
```
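
If this is the first time you're using Weights & Biases on this machine, you will also need to log in with your API key; one way to do that from Python is:

```python
import wandb

# Prompts for an API key (from https://wandb.ai/authorize) if you aren't logged in yet.
wandb.login()
```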

To disable wandb logging, remove the `--report_to="wandb"` and `--validation_prompts="A robot naruto, 4k photo"` flags from the examples below.

#### Fine-tune decoder
<br>

<!-- accelerate_snippet_start -->
```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch --mixed_precision="fp16"  train_text_to_image_decoder.py \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="kandi2-decoder-naruto-model"
```
<!-- accelerate_snippet_end -->


To train on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in the [ImageFolder with metadata](https://huggingface.co/docs/datasets/en/image_load#imagefolder-with-metadata) guide.
If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
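
For reference, the `--train_data_dir` folder is expected to contain your images plus a `metadata.jsonl` file mapping each image file to its caption (`file_name` is required by the `datasets` ImageFolder loader; the scripts here read captions from a `text` column). The file names and captions below are placeholders; a small sketch for generating such a file:

```python
import json
from pathlib import Path

data_dir = Path("path_to_your_dataset")  # the folder you pass via --train_data_dir

# Placeholder file names and captions; replace with your own data.
captions = {
    "0001.png": "a robot naruto, 4k photo",
    "0002.png": "a cute dragon creature",
}

with open(data_dir / "metadata.jsonl", "w") as f:
    for file_name, text in captions.items():
        f.write(json.dumps({"file_name": file_name, "text": text}) + "\n")
```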

```bash
export TRAIN_DIR="path_to_your_dataset"

accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
  --train_data_dir=$TRAIN_DIR \
  --resolution=768 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="kandi22-decoder-naruto-model"
```


Once the training is finished, the model will be saved in the `output_dir` specified in the command. In this example it's `kandi22-decoder-naruto-model`. To load the fine-tuned model for inference, just pass that path to `AutoPipelineForText2Image`:

```python
from diffusers import AutoPipelineForText2Image
import torch

output_dir = "kandi22-decoder-naruto-model"  # the --output_dir used during training
pipe = AutoPipelineForText2Image.from_pretrained(output_dir, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt='A robot naruto, 4k photo'
images = pipe(prompt=prompt).images
images[0].save("robot-naruto.png")
```

Checkpoints only save the `unet`, so to run inference from a checkpoint, just load the `unet`:
```python
from diffusers import AutoPipelineForText2Image, UNet2DConditionModel
import torch

model_path = "path_to_saved_model"

unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet")

pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

image = pipe(prompt="A robot naruto, 4k photo").images[0]
image.save("robot-naruto.png")
```

#### Fine-tune prior

You can fine-tune the Kandinsky prior model with the `train_text_to_image_prior.py` script. Note that we currently do not support `--gradient_checkpointing` for prior model fine-tuning.

<br>

<!-- accelerate_snippet_start -->
```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch --mixed_precision="fp16"  train_text_to_image_prior.py \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="kandi2-prior-naruto-model"
```
<!-- accelerate_snippet_end -->


To perform inference with the fine-tuned prior model, you will need to first create a prior pipeline by passing the `output_dir` to `DiffusionPipeline`. Then create a `KandinskyV22CombinedPipeline` from a pretrained or fine-tuned decoder checkpoint along with all the modules of the prior pipeline you just created.

```python
from diffusers import AutoPipelineForText2Image, DiffusionPipeline
import torch

output_dir = "kandi2-prior-naruto-model"  # the --output_dir used during prior training
pipe_prior = DiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16)
prior_components = {"prior_" + k: v for k,v in pipe_prior.components.items()}
pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components, torch_dtype=torch.float16)

pipe.enable_model_cpu_offload()
prompt='A robot naruto, 4k photo'
images = pipe(prompt=prompt).images
images[0].save("robot-naruto.png")
```

If you want to use a fine-tuned decoder checkpoint along with your fine-tuned prior checkpoint, you can simply replace "kandinsky-community/kandinsky-2-2-decoder" in the above code with your custom model repo name. Note that in order to create a `KandinskyV22CombinedPipeline`, your model repository needs to have a prior tag. If you have created your model repo using our training script, the prior tag is automatically included.
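
For example (the repo names below are placeholders for your own fine-tuned checkpoints on the Hub):

```python
from diffusers import AutoPipelineForText2Image, DiffusionPipeline
import torch

# Placeholder repo names; substitute your own fine-tuned prior and decoder checkpoints.
pipe_prior = DiffusionPipeline.from_pretrained("your-username/kandi2-prior-naruto-model", torch_dtype=torch.float16)
prior_components = {"prior_" + k: v for k, v in pipe_prior.components.items()}

pipe = AutoPipelineForText2Image.from_pretrained(
    "your-username/kandi22-decoder-naruto-model",  # fine-tuned decoder instead of the base checkpoint
    **prior_components,
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

image = pipe(prompt="A robot naruto, 4k photo").images[0]
```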

#### Training with multiple GPUs

`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
for running distributed training with `accelerate`. Here is an example command:

```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch --mixed_precision="fp16" --multi_gpu  train_text_to_image_decoder.py \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="kandi2-decoder-naruto-model"
```


#### Training with Min-SNR weighting

We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556), which helps achieve faster convergence by rebalancing the loss. To use it, pass the `--snr_gamma` argument and set it to the recommended value of 5.0.
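
Conceptually, the per-timestep MSE loss is down-weighted by `min(SNR(t), γ) / SNR(t)`, so easy high-SNR timesteps no longer dominate training. A simplified sketch of this re-weighting (assuming ε-prediction and an `snr` tensor computed from the noise scheduler; the actual training script differs in details):

```python
import torch
import torch.nn.functional as F

def min_snr_weighted_loss(model_pred, target, snr, snr_gamma=5.0):
    # Per-example MSE, averaged over all non-batch dimensions.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, loss.ndim)))
    # Clamp the SNR at gamma so high-SNR (easy) timesteps are down-weighted.
    weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
    return (loss * weights).mean()
```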


## Training with LoRA

Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:

- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow you to control the extent to which the model is adapted toward new training images via a `scale` parameter.

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.

With LoRA, it's possible to fine-tune Kandinsky 2.2 on a custom image-caption pair dataset
on consumer GPUs like the Tesla T4 or Tesla V100.

### Training

First, you need to set up your development environment as explained in the [installation section](#running-locally-with-pytorch) above. Make sure to set the `DATASET_NAME` environment variable. Here, we will use [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) and the [Naruto dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).


#### Train decoder

```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch --mixed_precision="fp16" train_text_to_image_decoder_lora.py \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=768 \
  --train_batch_size=1 \
  --num_train_epochs=100 --checkpointing_steps=5000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --rank=4 \
  --gradient_checkpointing \
  --output_dir="kandi22-decoder-naruto-lora" \
  --validation_prompt="cute dragon creature" --report_to="wandb" \
  --push_to_hub
```

#### Train prior

```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch --mixed_precision="fp16" train_text_to_image_prior_lora.py \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=768 \
  --train_batch_size=1 \
  --num_train_epochs=100 --checkpointing_steps=5000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --rank=4 \
  --output_dir="kandi22-prior-naruto-lora" \
  --validation_prompt="cute dragon creature" --report_to="wandb" \
  --push_to_hub
```

**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run the above scripts on consumer GPUs like the T4 or V100.___**


### Inference

#### Inference using fine-tuned LoRA checkpoint for decoder

Once you have trained a Kandinsky decoder model using the above command, inference can be done with `AutoPipelineForText2Image` after loading the trained LoRA weights. You need to pass the `output_dir` for loading the LoRA weights, which in this case is `kandi22-decoder-naruto-lora`.


```python
from diffusers import AutoPipelineForText2Image
import torch

output_dir = "kandi22-decoder-naruto-lora"  # the --output_dir used during LoRA training
pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
pipe.unet.load_attn_procs(output_dir)
pipe.enable_model_cpu_offload()

prompt='A robot naruto, 4k photo'
image = pipe(prompt=prompt).images[0]
image.save("robot_naruto.png")
```

#### Inference using fine-tuned LoRA checkpoint for prior

```python
from diffusers import AutoPipelineForText2Image
import torch

output_dir = "kandi22-prior-naruto-lora"  # the --output_dir used during prior LoRA training
pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
pipe.prior_prior.load_attn_procs(output_dir)
pipe.enable_model_cpu_offload()

prompt='A robot naruto, 4k photo'
image = pipe(prompt=prompt).images[0]
image.save("robot_naruto.png")
```

### Training with xFormers

You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.

xFormers training is not available for fine-tuning the prior model.

**Note**:

According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training in some GPUs. If you observe that problem, please install a development version as indicated in that comment.