Unverified Commit abe05822 authored by Duong A. Nguyen's avatar Duong A. Nguyen Committed by GitHub
Browse files

[Flax] Add finetune Stable Diffusion (#999)

* [Flax] Add finetune Stable Diffusion

* temporary fix

* drop_last and seed

* add dtype for mixed precision training

* style

* Add Flax example
parent 3be9fa97
......@@ -62,6 +62,24 @@ accelerate launch train_text_to_image.py \
--output_dir="sd-pokemon-model"
```
Or use the Flax implementation if you need a speedup
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export dataset_name="lambdalabs/pokemon-blip-captions"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-pokemon-model"
```
To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script.
......@@ -86,6 +104,24 @@ accelerate launch train_text_to_image.py \
--output_dir="sd-pokemon-model"
```
Or use the Flax implementation if you need a speedup
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-pokemon-model"
```
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment