## Textual Inversion fine-tuning example

[Textual inversion](https://arxiv.org/abs/2208.01618) is a method for personalizing text-to-image models such as Stable Diffusion on your own images using just 3-5 examples.
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
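
As a rough intuition for what the training procedure does, the toy sketch below adds a new placeholder token to a vocabulary, initializes its vector from an existing token, and then updates only that one vector while every other embedding stays frozen. Plain Python lists stand in for the text encoder's embedding table, and the names `add_placeholder` and `masked_sgd_step` as well as the gradient values are hypothetical; this is not the code used in `textual_inversion.py`.

```python
# Toy illustration only: a list of lists stands in for the embedding table.
vocab = {"a": 0, "photo": 1, "toy": 2}
embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]


def add_placeholder(vocab, embeddings, placeholder, initializer):
    """Append a new token whose vector starts as a copy of the initializer's."""
    new_id = len(embeddings)
    vocab[placeholder] = new_id
    embeddings.append(list(embeddings[vocab[initializer]]))
    return new_id


def masked_sgd_step(embeddings, grads, trainable_id, lr):
    """Apply a gradient step to the trainable row only; other rows stay frozen."""
    row = embeddings[trainable_id]
    embeddings[trainable_id] = [w - lr * g for w, g in zip(row, grads[trainable_id])]


placeholder_id = add_placeholder(vocab, embeddings, "<cat-toy>", "toy")
frozen_before = [list(row) for row in embeddings[:placeholder_id]]

# Hypothetical gradients for every row; only the placeholder row may change.
grads = [[1.0, 1.0] for _ in embeddings]
masked_sgd_step(embeddings, grads, placeholder_id, lr=0.1)

print(embeddings[:placeholder_id] == frozen_before)  # frozen rows are untouched
```

The point of the sketch is that the rest of the model is left unchanged, so the learned concept lives entirely in one new embedding vector.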

## Running on Colab 

Colab for training 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)

Colab for inference
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)

## Running locally with PyTorch
### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then `cd` into the example folder and run:
```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```


### Cat toy example

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree. 

You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

Run the following command to authenticate your token:

```bash
huggingface-cli login
```

If you have already cloned the repo, then you won't need to go through these steps. 

<br>

Now let's get our dataset. Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
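
Before launching a run, it can help to confirm that the directory really contains the images you expect. The sketch below is one way to do that with the standard library; the helper name `list_training_images` and the set of accepted extensions are assumptions made for illustration, and the training script may accept other formats.

```python
import tempfile
from pathlib import Path

# Assumed set of extensions; the training script may accept other formats.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def list_training_images(data_dir):
    """Return the image files found directly inside data_dir, sorted by name."""
    folder = Path(data_dir)
    if not folder.is_dir():
        raise FileNotFoundError(f"{data_dir} is not a directory")
    return sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


# Self-contained demo on a temporary directory with empty stand-in files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("1.jpeg", "2.jpeg", "notes.txt"):
        (Path(tmp) / name).touch()
    found = list_training_images(tmp)
    print([p.name for p in found])
```

In practice you would call `list_training_images` on the directory you later export as `DATA_DIR` and check that it returns the 3-4 images you downloaded.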

Then launch the training with the command below.

**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --output_dir="textual_inversion_cat"
```

A full training run takes ~1 hour on one V100 GPU.
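
The command above combines `--train_batch_size=1`, `--gradient_accumulation_steps=4` and `--scale_lr`. The sketch below works out the resulting effective batch size and learning rate, assuming that `--scale_lr` multiplies the base learning rate by the batch size, the number of accumulation steps and the number of processes. That scaling rule and the function names are assumptions for illustration, so check them against your copy of `textual_inversion.py`.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes):
    """Number of samples that contribute to a single optimizer step."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def scaled_learning_rate(base_lr, train_batch_size, gradient_accumulation_steps, num_processes):
    """Learning rate after scaling; the rule itself is an assumption (see text)."""
    return base_lr * effective_batch_size(
        train_batch_size, gradient_accumulation_steps, num_processes
    )


# Values from the command above; num_processes=1 assumes a single GPU.
print(effective_batch_size(1, 4, 1))
print(scaled_learning_rate(5.0e-04, 1, 4, 1))
```

Under these assumptions, a single-GPU run steps the optimizer once every 4 images with a learning rate of about 2e-3 rather than 5e-4.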

### Inference

Once you have trained a model with the command above, you can run inference with the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.

```python
import torch
from diffusers import StableDiffusionPipeline

model_id = "path-to-your-trained-model"
# Load the fine-tuned pipeline in half precision and move it to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# The prompt must contain the placeholder token used during training.
prompt = "A <cat-toy> backpack"

image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("cat-backpack.png")
```
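
Because the learned concept is only triggered when the placeholder token appears in the prompt, a small guard can catch prompts that omit it. The sketch below is a minimal example; the templates and the helper names `build_prompts` and `check_prompt` are hypothetical and not part of `diffusers`.

```python
PLACEHOLDER_TOKEN = "<cat-toy>"  # must match --placeholder_token from training

# Hypothetical prompt templates, chosen only for illustration.
TEMPLATES = [
    "A {} backpack",
    "a photo of a {} on a beach",
]


def build_prompts(templates, placeholder):
    """Fill each template with the placeholder token."""
    return [template.format(placeholder) for template in templates]


def check_prompt(prompt, placeholder):
    """Raise if the prompt would not trigger the learned concept."""
    if placeholder not in prompt:
        raise ValueError(f"prompt {prompt!r} does not contain {placeholder!r}")
    return prompt


prompts = [
    check_prompt(p, PLACEHOLDER_TOKEN)
    for p in build_prompts(TEMPLATES, PLACEHOLDER_TOKEN)
]
print(prompts)
```

Each resulting string can be passed as `prompt` to the pipeline call shown above.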


## Training with Flax/JAX

For faster training on TPUs and GPUs, you can use the Flax training example. Follow the instructions above to get the model and dataset before running the script.

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install -U -r requirements_flax.txt
```

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export DATA_DIR="path-to-dir-containing-images"

python textual_inversion_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --output_dir="textual_inversion_cat"
```
It should be at least 70% faster than the PyTorch script with the same configuration.

### Training with xFormers
You can enable memory-efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
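
If you assemble the launch command programmatically, you may want to add the flag only when xFormers is actually installed. The sketch below checks for the package with the standard library; the helper names `xformers_available` and `extra_launch_args` are hypothetical, and a successful lookup does not guarantee that the installed build matches your GPU and PyTorch version.

```python
import importlib.util


def xformers_available():
    """Return True if the xformers package can be found in this environment."""
    return importlib.util.find_spec("xformers") is not None


def extra_launch_args():
    """Return the optional flag for the PyTorch script when xFormers is present."""
    if xformers_available():
        return ["--enable_xformers_memory_efficient_attention"]
    return []


print(extra_launch_args())
```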