## Textual Inversion fine-tuning example

[Textual inversion](https://arxiv.org/abs/2208.01618) is a method for personalizing text-to-image models such as Stable Diffusion on your own images using just 3-5 examples.
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
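
The core idea can be illustrated with a toy sketch: the existing token embeddings stay frozen, and only the embedding vector of one new placeholder token is optimized. This is not what `textual_inversion.py` does internally. It swaps the diffusion loss for a hypothetical squared-distance loss to a fixed target vector, and every name in it (`frozen_table`, `train_placeholder`, `target`) is made up for illustration.

```python
# Toy illustration: optimize only one new embedding; everything else is frozen.
# The squared-distance loss is a stand-in assumption, not the diffusion loss.

def train_placeholder(frozen_table, init_token, target, lr=0.1, steps=200):
    """Return a learned vector for a new token, initialized from `init_token`."""
    # Copy the initializer embedding so the frozen table is never modified.
    vec = list(frozen_table[init_token])
    for _ in range(steps):
        # Gradient of sum((v - t)^2) with respect to v is 2 * (v - t).
        grad = [2.0 * (v - t) for v, t in zip(vec, target)]
        vec = [v - lr * g for v, g in zip(vec, grad)]
    return vec


frozen_table = {"toy": [1.0, 0.0], "cat": [0.0, 1.0]}
target = [0.5, 0.5]  # hypothetical "concept" the new token should capture
learned = train_placeholder(frozen_table, "toy", target)
print([round(x, 4) for x in learned])
```

The initializer token plays the same role here as `--initializer_token` in the training command: it gives the new embedding a sensible starting point.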

## Running on Colab 

Colab for training 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)

Colab for inference
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)

## Running locally 
### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install diffusers"[training]" accelerate "transformers>=4.21.0"
```

Then initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
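
To confirm that the packages from the install step are present, a small check along these lines may help. It is a simplified sketch: the helper names (`version_tuple`, `meets_minimum`, `check_package`) are hypothetical, and the version parsing only handles plain dotted numbers, ignoring suffixes such as `dev0`.

```python
from importlib import metadata


def version_tuple(version):
    """Turn '4.21.0' into (4, 21, 0); parsing stops at the first non-numeric piece."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare two dotted version strings numerically, padding with zeros."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_package(name, minimum=None):
    """Return a short status string for one package."""
    try:
        installed = metadata.version(name)
    except metadata.PackageNotFoundError:
        return f"{name}: not installed"
    if minimum and not meets_minimum(installed, minimum):
        return f"{name}: {installed} is older than {minimum}"
    return f"{name}: {installed} ok"


# Minimum version taken from the pip command above.
for pkg, minimum in [("diffusers", None), ("accelerate", None), ("transformers", "4.21.0")]:
    print(check_package(pkg, minimum))
```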


### Cat toy example

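You need to accept the model license before downloading or using the weights. This section refers to model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. Note that the training command below sets `MODEL_NAME` to `runwayml/stable-diffusion-v1-5`; if you train from that checkpoint, check its model card for the license terms as well.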

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

Run the following command to authenticate your token:

```bash
huggingface-cli login
```

If you have already cloned the repo, then you won't need to go through these steps. 

<br>

Now let's get our dataset. Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
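
Before training, it can be worth checking that the directory really contains the images you expect. The helper below is a minimal sketch, not part of the training script: `list_training_images` and the `IMAGE_SUFFIXES` set are assumptions made for illustration, so adjust the suffixes to match your files.

```python
from pathlib import Path

# Assumed set of image extensions; extend it if your files use others.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}


def list_training_images(data_dir):
    """Return the sorted image paths found directly inside `data_dir`."""
    folder = Path(data_dir)
    if not folder.is_dir():
        raise NotADirectoryError(f"{data_dir} is not a directory")
    return sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
```

Calling `list_training_images("path-to-dir-containing-images")` should return the 3-4 files you downloaded; an empty list means the path or the file extensions need another look.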

Then launch the training with:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --output_dir="textual_inversion_cat"
```

A full training run takes ~1 hour on one V100 GPU.
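
The batch-size and learning-rate flags in the command interact. The sketch below assumes that `--scale_lr` multiplies the base learning rate by the gradient accumulation steps, the per-device batch size, and the number of processes; that formula is an assumption here, so verify it against the version of `textual_inversion.py` you run. The function names are hypothetical.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Number of samples contributing to each optimizer step."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def scaled_learning_rate(base_lr, train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Learning rate after scaling (assumed behaviour of --scale_lr)."""
    return base_lr * effective_batch_size(
        train_batch_size, gradient_accumulation_steps, num_processes
    )


# Values from the command above, on a single process.
print(effective_batch_size(1, 4))
print(scaled_learning_rate(5.0e-4, 1, 4))
```

Under that assumption, adding GPUs or raising `--gradient_accumulation_steps` also raises the learning rate actually used, which is worth keeping in mind when changing either.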


### Inference

Once you have trained a model using the command above, you can run inference with the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.

```python
import torch
from diffusers import StableDiffusionPipeline

model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "A <cat-toy> backpack"

image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("cat-backpack.png")
```
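
Since a prompt without the placeholder token will not invoke the learned concept, a small guard can catch that mistake before any GPU time is spent. This is a minimal sketch; `require_placeholder` is a hypothetical helper, not part of `diffusers`.

```python
def require_placeholder(prompt, placeholder_token="<cat-toy>"):
    """Return the prompt unchanged, or raise if the placeholder token is missing."""
    if placeholder_token not in prompt:
        raise ValueError(
            f"Prompt {prompt!r} does not contain the placeholder token {placeholder_token!r}"
        )
    return prompt


prompt = require_placeholder("A <cat-toy> backpack")
print(prompt)
```

If you trained with a different `--placeholder_token`, pass it as the second argument.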