Unverified Commit 52f2128d authored by Suraj Patil, committed by GitHub

update readme for flax examples (#1026)

parent fbcc3833
...@@ -4,7 +4,7 @@
The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:
...@@ -58,24 +58,6 @@ accelerate launch train_dreambooth.py \
--max_train_steps=400
```
### Training with prior-preservation loss

Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
...@@ -105,28 +87,6 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800
```
### Training on a 16GB GPU:
...@@ -234,7 +194,58 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800
```
### Inference

Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.
```python
from diffusers import StableDiffusionPipeline
import torch
model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
prompt = "A photo of sks dog in a bucket"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```
## Running with Flax/JAX

For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

___Note: The flax example doesn't yet support features like gradient checkpointing and gradient accumulation, so to use flax for faster training we will need >30GB cards.___
Before running the scripts, make sure to install the library's training dependencies:
```bash
pip install -U -r requirements_flax.txt
```
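Before launching a run, it can help to confirm that JAX actually sees your TPU or GPU devices. This quick check is not part of the original instructions, just a convenience:

```python
import jax

# Lists the accelerator devices JAX will use; a CPU-only list means training
# would silently fall back to the (much slower) CPU backend.
print(jax.devices())
print("device count:", jax.device_count())
```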
### Training without prior preservation loss
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
python train_dreambooth_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--learning_rate=5e-6 \
--max_train_steps=400
```
### Training with prior preservation loss
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
...@@ -244,7 +255,6 @@ export OUTPUT_DIR="path-to-save-model"
python train_dreambooth_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
...@@ -253,24 +263,32 @@ python train_dreambooth_flax.py \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--learning_rate=5e-6 \
--num_class_images=200 \
--max_train_steps=800
```
### Fine-tune text encoder with the UNet.

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
python train_dreambooth_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_text_encoder \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--learning_rate=2e-6 \
--num_class_images=200 \
--max_train_steps=800
```
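The inference snippet above uses the PyTorch `StableDiffusionPipeline`. For a checkpoint produced by `train_dreambooth_flax.py`, a minimal sketch along the lines of the usual Flax inference pattern is shown below; it assumes the `--output_dir` above contains a Flax-format pipeline, and the dtype, prompt, and device layout are illustrative:

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

# Assumption: the --output_dir used for train_dreambooth_flax.py above.
model_path = "path-to-save-model"
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)

prompt = "A photo of sks dog in a bucket"
num_devices = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_devices)

# Replicate the parameters and shard the prompts across all available devices.
params = replicate(params)
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, rng, num_inference_steps=50, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_devices,) + images.shape[-3:])))
images[0].save("dog-bucket.png")
```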
The new `requirements_flax.txt` for this example:

diffusers>=0.5.1
transformers>=4.21.0
flax
optax
torch
torchvision
ftfy
tensorboard
modelcards
...@@ -7,7 +7,7 @@ ___Note___:
___This script is experimental. The script fine-tunes the whole model and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

...@@ -62,24 +62,6 @@ accelerate launch train_text_to_image.py \
--output_dir="sd-pokemon-model"
```
To run on your own training files, prepare the dataset according to the format required by `datasets`; you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
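As a small illustration of that layout (the folder name `train_data/`, the file names, and the `text` caption column are assumptions here, not taken from the original README), an `imagefolder` dataset with captions can be sanity-checked with `datasets` before training:

```python
# Expected "imagefolder with metadata" layout (names are illustrative):
#
# train_data/
#   metadata.jsonl   # one JSON object per line, e.g. {"file_name": "0001.png", "text": "a caption"}
#   0001.png
#   0002.png
from datasets import load_dataset

# Loading the folder locally is a quick way to verify that images and captions
# line up before pointing the training script at it.
dataset = load_dataset("imagefolder", data_dir="train_data", split="train")
print(dataset[0]["text"])
print(dataset[0]["image"].size)
```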
...@@ -104,15 +86,43 @@ accelerate launch train_text_to_image.py \
--output_dir="sd-pokemon-model"
```
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:

```python
import torch
from diffusers import StableDiffusionPipeline

model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")
image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
```
## Training with Flax/JAX

For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

___Note: The flax example doesn't yet support features like gradient checkpointing and gradient accumulation, so to use flax for faster training we will need >30GB cards.___
Before running the scripts, make sure to install the library's training dependencies:
```bash
pip install -U -r requirements_flax.txt
```
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export dataset_name="lambdalabs/pokemon-blip-captions"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
...@@ -122,16 +132,22 @@ python train_text_to_image_flax.py \
--output_dir="sd-pokemon-model"
```
To run on your own training files, prepare the dataset according to the format required by `datasets`; you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-pokemon-model"
```
The new `requirements_flax.txt` for this example:

diffusers>=0.5.1
transformers>=4.21.0
flax
optax
torch
torchvision
ftfy
tensorboard
modelcards
...@@ -11,7 +11,7 @@ Colab for training
Colab for inference
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

...@@ -68,7 +68,33 @@ accelerate launch textual_inversion.py \
A full training run takes ~1 hour on one V100 GPU.

### Inference

Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
```python
import torch
from diffusers import StableDiffusionPipeline

model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")
```
## Training with Flax/JAX
For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
Before running the scripts, make sure to install the library's training dependencies:
```bash
pip install -U -r requirements_flax.txt
```
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
...@@ -86,20 +112,3 @@ python textual_inversion_flax.py \
--output_dir="textual_inversion_cat"
```
It should be at least 70% faster than the PyTorch script with the same configuration.
The new `requirements_flax.txt` for this example:

diffusers>=0.5.1
transformers>=4.21.0
flax
optax
torch
torchvision
ftfy
tensorboard
modelcards