 <!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Textual Inversion

[[open-in-colab]]

[Textual Inversion](https://arxiv.org/abs/2208.01618) is a technique for capturing novel concepts from a small number of example images. While the technique was originally demonstrated with a [latent diffusion model](https://github.com/CompVis/latent-diffusion), it has since been applied to other model variants like [Stable Diffusion](https://huggingface.co/docs/diffusers/main/en/conceptual/stable_diffusion). The learned concepts can be used to better control the images generated from text-to-image pipelines. It learns new "words" in the text encoder's embedding space, which are used within text prompts for personalized image generation.

![Textual Inversion example](https://textual-inversion.github.io/static/images/editing/colorful_teapot.JPG)
<small>By using just 3-5 images you can teach new concepts to a model such as Stable Diffusion for personalized image generation <a href="https://github.com/rinongal/textual_inversion">(image source)</a>.</small>

This guide will show you how to train a [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model with Textual Inversion. All the training scripts for Textual Inversion used in this guide can be found [here](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) if you're interested in taking a closer look at how things work under the hood.

<Tip>

There is a community-created collection of trained Textual Inversion models in the [Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library) which are readily available for inference. Over time, this'll hopefully grow into a useful resource as more concepts are added!

</Tip>

Before you begin, make sure you install the library's training dependencies:

```bash
pip install diffusers accelerate transformers
```

After all the dependencies have been set up, initialize a [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

To set up a default 🤗 Accelerate environment without choosing any configurations:

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell like a notebook, you can use:

```py
from accelerate.utils import write_basic_config

write_basic_config()
```

Finally, you can try to [install xFormers](https://huggingface.co/docs/diffusers/main/en/training/optimization/xformers) to reduce your memory footprint with xFormers memory-efficient attention. Once you have xFormers installed, add the `--enable_xformers_memory_efficient_attention` argument to the training script. Note that xFormers is not supported for Flax.

## Upload model to Hub

If you want to store your model on the Hub, add the following argument to the training script:

```bash
--push_to_hub
```

## Save and load checkpoints

It is often a good idea to regularly save checkpoints of your model during training. This way, you can resume training from a saved checkpoint if your training is interrupted for any reason. To save a checkpoint, pass the following argument to the training script to save the full training state in a subfolder in `output_dir` every 500 steps:

```bash
--checkpointing_steps=500
```

To resume training from a saved checkpoint, pass the following argument to the training script along with the name of the checkpoint you'd like to resume from:

```bash
--resume_from_checkpoint="checkpoint-1500"
```

## Finetuning

For your training dataset, download these [images of a cat toy](https://huggingface.co/datasets/diffusers/cat_toy_example) and store them in a directory. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide.

```py
from huggingface_hub import snapshot_download

local_dir = "./cat"
snapshot_download(
    "diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes"
)
```

Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument, and set the `DATA_DIR` environment variable to the path of the directory containing the images.

Now you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py). The script creates and saves the following files to your repository: `learned_embeds.bin`, `token_identifier.txt`, and `type_of_concept.txt`.

<Tip>

💡 A full training run takes ~1 hour on one V100 GPU. While you're waiting for the training to complete, feel free to check out [how Textual Inversion works](#how-it-works) in the section below if you're curious!

</Tip>

<frameworkcontent>
<pt>
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="./cat"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --output_dir="textual_inversion_cat"
```

<Tip>

💡 If you want to increase the trainable capacity, you can associate your placeholder token, *e.g.* `<cat-toy>`, with multiple embedding vectors. This can help the model better capture the style of more complex images. To enable training multiple embedding vectors, simply pass:

```bash
--num_vectors=5
```

</Tip>
</pt>
<jax>
If you have access to TPUs, try out the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py) to train even faster (this'll also work for GPUs). With the same configuration settings, the Flax training script should be at least 70% faster than the PyTorch training script! ⚡️

Before you begin, make sure you install the Flax-specific dependencies:

```bash
pip install -U -r requirements_flax.txt
```

Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument.

Then you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py):

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export DATA_DIR="./cat"

python textual_inversion_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --output_dir="textual_inversion_cat"
```
</jax>
</frameworkcontent>

### Intermediate logging

If you're interested in following along with your model training progress, you can save the generated images from the training process. Add the following arguments to the training script to enable intermediate logging:

- `validation_prompt`, the prompt used to generate samples (this is set to `None` by default and intermediate logging is disabled)
- `num_validation_images`, the number of sample images to generate
- `validation_steps`, the number of steps before generating `num_validation_images` from the `validation_prompt`

```bash
--validation_prompt="A <cat-toy> backpack"
--num_validation_images=4
--validation_steps=100
```

## Inference

Once you have trained a model, you can use it for inference with the [`StableDiffusionPipeline`].

By default, the textual inversion script only saves the textual inversion embedding vector(s) that have been added to the text encoder embedding matrix and trained.

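If you'd like to peek at what was saved, `learned_embeds.bin` is a small PyTorch file that maps your placeholder token to its learned embedding vector(s). Here is a minimal inspection sketch, assuming the `--output_dir="textual_inversion_cat"` used in the training command above:

```py
import torch

# assumes --output_dir="textual_inversion_cat" from the training command above
learned_embeds = torch.load("textual_inversion_cat/learned_embeds.bin")

# each entry maps a placeholder token to its trained embedding vector(s)
for token, embedding in learned_embeds.items():
    print(token, tuple(embedding.shape))
```
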
<frameworkcontent>
<pt>
<Tip>

💡 The community has created a large library of different textual inversion embedding vectors, called [sd-concepts-library](https://huggingface.co/sd-concepts-library).
Instead of training textual inversion embeddings from scratch, you can also check whether a fitting textual inversion embedding has already been added to the library.

</Tip>

To load the textual inversion embeddings, you first need to load the base model that was used when training your textual inversion embedding vectors. Here we assume that [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) was used as the base model, so we load it first:
```python
from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
```

Next, we need to load the textual inversion embedding vector, which can be done via the [`TextualInversionLoaderMixin.load_textual_inversion`] function. Here we'll load the embeddings of the `<cat-toy>` example from before.
```python
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
```
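If you trained your own embeddings with the script above, you can instead point [`TextualInversionLoaderMixin.load_textual_inversion`] at the local output directory. A minimal sketch, assuming the `--output_dir="textual_inversion_cat"` from the training command earlier:

```py
# load a locally trained embedding instead of one from the Hub
pipe.load_textual_inversion("textual_inversion_cat")
```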

Now we can run the pipeline, making sure the placeholder token `<cat-toy>` is used in our prompt.

```python
prompt = "A <cat-toy> backpack"

image = pipe(prompt, num_inference_steps=50).images[0]
image.save("cat-backpack.png")
```

The function [`TextualInversionLoaderMixin.load_textual_inversion`] can not only 
load textual embedding vectors saved in Diffusers' format, but also embedding vectors
saved in [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) format.
To do so, you can first download an embedding vector from [civitAI](https://civitai.com/models/3036?modelVersionId=8387)
and then load it locally:
```python
pipe.load_textual_inversion("./charturnerv2.pt")
```
</pt>
<jax>
Currently there is no `load_textual_inversion` function for Flax, so you have to make sure the textual inversion embedding vector is saved as part of the model after training.

The model can then be run just like any other Flax model:

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

model_path = "path-to-your-trained-model"
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)

prompt = "A <cat-toy> backpack"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng across devices
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

# save the first generated image
images[0].save("cat-backpack.png")
```
</jax>
</frameworkcontent>

## How it works

![Diagram from the paper showing overview](https://textual-inversion.github.io/static/images/training/training.JPG)
<small>Architecture overview from the Textual Inversion <a href="https://textual-inversion.github.io/">blog post.</a></small>

Usually, text prompts are tokenized into an embedding before being passed to a model, which is often a transformer. Textual Inversion does something similar, but it learns a new token embedding, `v*`, from a special token `S*` in the diagram above. The model output is used to condition the diffusion model, which helps the diffusion model understand the prompt and new concepts from just a few example images.

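Concretely, the new "word" is just an extra row in the text encoder's token embedding matrix. The following is a minimal sketch of that setup step, roughly mirroring what the training script does internally (the model id and tokens below match the example above):

```py
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "runwayml/stable-diffusion-v1-5"
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# register the placeholder token S* (here "<cat-toy>") as a new token
tokenizer.add_tokens("<cat-toy>")
text_encoder.resize_token_embeddings(len(tokenizer))

# initialize its embedding v* from an existing token, e.g. the initializer token "toy"
token_embeds = text_encoder.get_input_embeddings().weight.data
placeholder_id = tokenizer.convert_tokens_to_ids("<cat-toy>")
initializer_id = tokenizer.convert_tokens_to_ids("toy")
token_embeds[placeholder_id] = token_embeds[initializer_id].clone()
```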
To do this, Textual Inversion uses a generator model and noisy versions of the training images. The generator tries to predict less noisy versions of the images, and the token embedding `v*` is optimized based on how well the generator does. If the token embedding successfully captures the new concept, it gives more useful information to the diffusion model and helps create clearer images with less noise. This optimization process typically occurs after several thousand steps of exposure to a variety of prompt and image variants.
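
In other words, only the new embedding row is updated with the standard diffusion denoising loss while the VAE, UNet, and the rest of the text encoder stay frozen. Below is a rough sketch of a single training step under those assumptions, continuing from the snippet above; it is an illustration of the idea, not the actual training script:

```py
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

# continues from the sketch above: model_id, tokenizer, text_encoder, and placeholder_id are defined
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

# freeze everything except the token embedding matrix
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

optimizer = torch.optim.AdamW(text_encoder.get_input_embeddings().parameters(), lr=5e-4)


def training_step(pixel_values, input_ids):
    """One optimization step; pixel_values is a (1, 3, 512, 512) image tensor in [-1, 1],
    input_ids is a tokenized prompt that contains the placeholder token."""
    latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],))
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # condition the UNet on the prompt that includes the placeholder token
    encoder_hidden_states = text_encoder(input_ids)[0]
    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    loss = F.mse_loss(noise_pred.float(), noise.float())
    loss.backward()

    # zero out gradients for every embedding row except the placeholder token's
    grads = text_encoder.get_input_embeddings().weight.grad
    keep = torch.zeros(grads.shape[0], dtype=torch.bool)
    keep[placeholder_id] = True
    grads[~keep] = 0.0

    optimizer.step()
    optimizer.zero_grad()
    return loss
```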