<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Low-Rank Adaptation of Large Language Models (LoRA)

[[open-in-colab]]

<Tip warning={true}>

Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionModel`]. We also support fine-tuning the text encoder for DreamBooth with LoRA in a limited capacity. Fine-tuning the text encoder for DreamBooth generally yields better results, but it can increase compute usage.

</Tip>

[Low-Rank Adaptation of Large Language Models (LoRA)](https://arxiv.org/abs/2106.09685) is a training method that accelerates the training of large models while consuming less memory. It adds pairs of rank-decomposition weight matrices (called **update matrices**) to existing weights, and **only** trains those newly added weights. This has several advantages:

- Previous pretrained weights are kept frozen so the model is not as prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA matrices are generally added to the attention layers of the original model. 🧨 Diffusers provides the [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method to load the LoRA weights into a model's attention layers. You can control the extent to which the model is adapted toward new training images via a `scale` parameter. 
- The greater memory-efficiency allows you to run fine-tuning on consumer GPUs like the Tesla T4, RTX 3080 or even the RTX 2080 Ti! GPUs like the T4 are free and readily accessible in Kaggle or Google Colab notebooks.
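
To make the parameter savings concrete, here is a minimal sketch of the arithmetic behind a rank-decomposition update. The layer size (768 x 768) and the rank (4) are illustrative assumptions rather than values taken from a specific model, and `lora_param_count` is a hypothetical helper written for this example.

```python
def lora_param_count(d_out, d_in, rank):
    """Return (full, lora) trainable parameter counts for one weight matrix."""
    # Full fine-tuning updates every entry of the (d_out x d_in) weight matrix.
    full = d_out * d_in
    # LoRA keeps that matrix frozen and trains two small update matrices instead:
    # one of shape (d_out x rank) and one of shape (rank x d_in).
    lora = d_out * rank + rank * d_in
    return full, lora


# Assumed sizes for illustration only.
full, lora = lora_param_count(d_out=768, d_in=768, rank=4)
print(f"full fine-tuning: {full:,} parameters")
print(f"LoRA update:      {lora:,} parameters ({full // lora}x fewer)")
```

Because the two update matrices grow with the rank instead of with the full layer size, the saved LoRA weights stay small even when the base model is large.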

<Tip>

💡 LoRA is not limited to attention layers. The authors found that adapting the attention layers of a language model is sufficient to obtain good downstream performance with great efficiency. This is why it's common to add the LoRA weights only to the attention layers of a model. Check out the [Using LoRA for efficient Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) blog for more information about how LoRA works!

</Tip>

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. 🧨 Diffusers now supports finetuning with LoRA for [text-to-image generation](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) and [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora). This guide will show you how to do both.

If you'd like to store or share your model with the community, log in to your Hugging Face account (create [one](https://hf.co/join) if you don't have one already):

```bash
huggingface-cli login
```

## Text-to-image

Finetuning a model like Stable Diffusion, which has roughly a billion parameters, can be slow and difficult. With LoRA, it is much easier and faster to finetune a diffusion model. It can run on hardware with as little as 11GB of GPU RAM without resorting to tricks such as 8-bit optimizers.

### Training[[text-to-image-training]]

Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate your own Pokémon.

Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. You'll also need to set the `DATASET_NAME` environment variable to the name of the dataset you want to train on. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide.

The `OUTPUT_DIR` and `HUB_MODEL_ID` variables are optional and specify where to save the model locally and on the Hub, respectively:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="/sddata/finetune/lora/pokemon"
export HUB_MODEL_ID="pokemon-lora"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
```

There are some flags to be aware of before you start training:

* `--push_to_hub` stores the trained LoRA weights on the Hub.
* `--report_to=wandb` reports and logs the training results to your Weights & Biases dashboard (as an example, take a look at this [report](https://wandb.ai/pcuenq/text2image-fine-tune/runs/b4k1w0tn?workspace=user-pcuenq)).
* `--learning_rate=1e-04` sets the learning rate; with LoRA, you can afford to use a higher learning rate than you normally would.

Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)). Training takes about 5 hours on a 2080 Ti GPU with 11GB of VRAM, and the script creates and saves model checkpoints and the `pytorch_lora_weights` file in your repository.

```bash
accelerate launch --mixed_precision="fp16"  train_text_to_image_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME \
  --dataloader_num_workers=8 \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-04 \
  --max_grad_norm=1 \
  --lr_scheduler="cosine" --lr_warmup_steps=0 \
  --output_dir=${OUTPUT_DIR} \
  --push_to_hub \
  --hub_model_id=${HUB_MODEL_ID} \
  --report_to=wandb \
  --checkpointing_steps=500 \
  --validation_prompt="A pokemon with blue eyes." \
  --seed=1337
```
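
As a rough sanity check on the flags in the command above, the sketch below works out the effective batch size, the number of checkpoints, and the average time per step. It assumes a single GPU and that `--max_train_steps` counts optimizer updates; the 5 hour figure is the estimate quoted above for a 2080 Ti, so treat the result as an approximation.

```python
# Values copied from the launch command above.
train_batch_size = 1
gradient_accumulation_steps = 4
max_train_steps = 15000
checkpointing_steps = 500

# Assumptions for this sketch: one GPU and roughly 5 hours of total training time.
num_gpus = 1
training_hours = 5

# Gradients from several small batches are accumulated before each optimizer update.
effective_batch_size = train_batch_size * gradient_accumulation_steps * num_gpus
images_processed = effective_batch_size * max_train_steps
num_checkpoints = max_train_steps // checkpointing_steps
seconds_per_step = training_hours * 3600 / max_train_steps

print(f"effective batch size: {effective_batch_size}")
print(f"images processed:     {images_processed}")
print(f"checkpoints saved:    {num_checkpoints}")
print(f"seconds per step:     {seconds_per_step:.1f}")
```

If disk space is a concern, raising `--checkpointing_steps` reduces the number of intermediate checkpoints without changing the training itself.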

### Inference[[text-to-image-inference]]

Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`] and then replacing its scheduler with the [`DPMSolverMultistepScheduler`]:

```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

>>> model_base = "runwayml/stable-diffusion-v1-5"

>>> pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```

Load the LoRA weights from your finetuned model *on top of the base model weights*, and then move the pipeline to a GPU for faster inference. When you merge the LoRA weights with the frozen pretrained model weights, you can optionally adjust how much of the weights to merge with the `scale` parameter:

<Tip>

💡 A `scale` value of `0` is the same as not using your LoRA weights, so only the base model weights are used, and a `scale` value of `1` means the fully finetuned LoRA weights are used. Values between `0` and `1` interpolate between the two sets of weights.

</Tip>

```py
>>> # `lora_model_path` is the local directory or Hub repository id that holds your trained LoRA weights
>>> pipe.unet.load_attn_procs(lora_model_path)
>>> pipe.to("cuda")

>>> # use half the weights from the LoRA finetuned model and half the weights from the base model
>>> image = pipe(
...     "A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5, cross_attention_kwargs={"scale": 0.5}
... ).images[0]

>>> # use the weights from the fully finetuned LoRA model
>>> image = pipe("A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5).images[0]
>>> image.save("blue_pokemon.png")
```
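
To see what `scale` does numerically, here is a minimal sketch with toy 2x2 matrices. It assumes the LoRA update is applied as `W + scale * (B @ A)`, following the LoRA paper; `matmul` and `apply_lora` are hypothetical helpers written for this illustration and are not part of 🧨 Diffusers.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def apply_lora(base, lora_b, lora_a, scale):
    """Return base + scale * (lora_b @ lora_a), computed element by element."""
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(base, delta)
    ]


base = [[1.0, 0.0], [0.0, 1.0]]  # frozen pretrained weight (toy example)
lora_b = [[1.0], [2.0]]          # 2x1 update matrix
lora_a = [[0.5, 0.5]]            # 1x2 update matrix

base_only = apply_lora(base, lora_b, lora_a, scale=0.0)  # identical to the base weights
half = apply_lora(base, lora_b, lora_a, scale=0.5)       # half of the update is added
full = apply_lora(base, lora_b, lora_a, scale=1.0)       # the whole update is added

print(base_only)
print(half)
print(full)
```

In this sketch the base weights are never modified; `scale` only controls how much of the low-rank update is added on top of them.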

<Tip>

If you are loading the LoRA parameters from the Hub and the Hub repository has a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then you can read the base model id from the model card:

```py
import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub.repocard import RepoCard

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
# read the `base_model` tag from the model card metadata
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
...
```

</Tip>

## DreamBooth

[DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type).

<Tip>

💡 Take a look at the [Training Stable Diffusion with DreamBooth using 🧨 Diffusers](https://huggingface.co/blog/dreambooth) blog for an in-depth analysis of DreamBooth experiments and recommended settings.

</Tip>

### Training[[dreambooth-training]]

Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) with DreamBooth and LoRA with some 🐶 [dog images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ). Download and save these images to a directory. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide.

To start, specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. You'll also need to set `INSTANCE_DIR` to the path of the directory containing the images. 

The `OUTPUT_DIR` variable is optional and specifies where to save the model:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
```

There are some flags to be aware of before you start training:

* `--push_to_hub` stores the trained LoRA weights on the Hub.
* `--report_to=wandb` reports and logs the training results to your Weights & Biases dashboard (as an example, take a look at this [report](https://wandb.ai/pcuenq/text2image-fine-tune/runs/b4k1w0tn?workspace=user-pcuenq)).
* `--learning_rate=1e-04` sets the learning rate; with LoRA, you can afford to use a higher learning rate than you normally would.

Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)). The script creates and saves model checkpoints and the `pytorch_lora_weights.bin` file in your repository.

It's also possible to fine-tune the text encoder with LoRA. In most cases, this leads to better results with a slight increase in compute. To fine-tune the text encoder with LoRA, specify the `--train_text_encoder` flag when launching the `train_dreambooth_lora.py` script.

```bash
accelerate launch train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=50 \
  --seed="0" \
  --push_to_hub
``` 

### Inference[[dreambooth-inference]]

Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:

```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline

>>> model_base = "runwayml/stable-diffusion-v1-5"

>>> pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
```

Load the LoRA weights from your finetuned DreamBooth model *on top of the base model weights*, and then move the pipeline to a GPU for faster inference. When you merge the LoRA weights with the frozen pretrained model weights, you can optionally adjust how much of the weights to merge with the `scale` parameter:

<Tip>

💡 A `scale` value of `0` is the same as not using your LoRA weights, so only the base model weights are used, and a `scale` value of `1` means the fully finetuned LoRA weights are used. Values between `0` and `1` interpolate between the two sets of weights.

</Tip>

```py
>>> # `lora_model_path` is the local directory or Hub repository id that holds your trained LoRA weights
>>> pipe.unet.load_attn_procs(lora_model_path)
>>> pipe.to("cuda")

>>> # use half the weights from the LoRA finetuned model and half the weights from the base model
>>> image = pipe(
...     "A picture of a sks dog in a bucket.",
...     num_inference_steps=25,
...     guidance_scale=7.5,
...     cross_attention_kwargs={"scale": 0.5},
... ).images[0]

>>> # use the weights from the fully finetuned LoRA model
>>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
>>> image.save("bucket-dog.png")
```

If you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA
weights. For example:

```python
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline
import torch

lora_model_id = "sayakpaul/dreambooth-text-encoder-test"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id)
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```

Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is preferred to [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters. This is because
[`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] can handle the following situations:

* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:

  ```py 
  pipe.load_lora_weights(lora_model_path)
  ```

* LoRA parameters that have separate identifiers for the UNet and the text encoder, such as [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
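
To illustrate what "separate identifiers" can look like, the sketch below groups the keys of a LoRA state dict by component prefix. The key names and the `unet.` / `text_encoder.` prefixes are hypothetical placeholders chosen for this example (inspect a real checkpoint to see its actual keys), and `split_by_component` is not a 🧨 Diffusers function.

```python
def split_by_component(state_dict, prefixes=("unet.", "text_encoder.")):
    """Group entries by key prefix; keys without a known prefix land under None."""
    groups = {prefix: {} for prefix in prefixes}
    groups[None] = {}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                # Strip the prefix so each component sees its own key names.
                groups[prefix][key[len(prefix):]] = value
                break
        else:
            # No prefix matched: the checkpoint does not say which component owns this key.
            groups[None][key] = value
    return groups


# Hypothetical key names; the strings stand in for tensors.
example = {
    "unet.attn1.to_q.lora.down.weight": "tensor A",
    "unet.attn1.to_q.lora.up.weight": "tensor B",
    "text_encoder.q_proj.lora.down.weight": "tensor C",
}

groups = split_by_component(example)
print(sorted(groups["unet."]))
print(sorted(groups["text_encoder."]))
```

When every key ends up in the `None` group, the checkpoint carries no component identifiers, which corresponds to the first situation in the list above.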