Unverified Commit abe05822 authored by Duong A. Nguyen's avatar Duong A. Nguyen Committed by GitHub
Browse files

[Flax] Add finetune Stable Diffusion (#999)

* [Flax] Add finetune Stable Diffusion

* temporary fix

* drop_last and seed

* add dtype for mixed precision training

* style

* Add Flax example
parent 3be9fa97
...@@ -62,6 +62,24 @@ accelerate launch train_text_to_image.py \ ...@@ -62,6 +62,24 @@ accelerate launch train_text_to_image.py \
--output_dir="sd-pokemon-model" --output_dir="sd-pokemon-model"
``` ```
Or use the Flax implementation if you need a speedup
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export dataset_name="lambdalabs/pokemon-blip-captions"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-pokemon-model"
```
To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script. If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script.
...@@ -86,6 +104,24 @@ accelerate launch train_text_to_image.py \ ...@@ -86,6 +104,24 @@ accelerate launch train_text_to_image.py \
--output_dir="sd-pokemon-model" --output_dir="sd-pokemon-model"
``` ```
Or use the Flax implementation if you need a speedup
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-pokemon-model"
```
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline` Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment