Unverified Commit 2a97067b authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Experimental] Diffusion LoRA DPO training (#6422)

* add: experimental script for diffusion dpo training.

* random_crop cli.

* fix: caption tokenization.

* fix: pixel_values index.

* fix: grad?

* debug

* fix: reduction.

* fixes in the loss calculation.

* style

* fix: unwrap call.

* fix: validation inference.

* add: initial sdxl script

* debug

* make sure images in the tuple are of same res

* fix model_max_length

* report print

* boom

* fix: numerical issues.

* fix: resolution

* comment about resize.

* change the order of the training transformation.

* save call.

* debug

* remove print

* manually detaching necessary?

* use the same vae for validation.

* add: readme.
parent ae060fc4
# Diffusion Model Alignment Using Direct Preference Optimization
This directory provides LoRA implementations of Diffusion DPO proposed in [DiffusionModel Alignment Using Direct Preference Optimization](https://arxiv.org/abs/2311.12908) by Bram Wallace, Meihua Dang, Rafael Rafailov, Linqi Zhou, Aaron Lou, Senthil Purushwalkam, Stefano Ermon, Caiming Xiong, Shafiq Joty, and Nikhil Naik.
We provide implementations for both Stable Diffusion (SD) and Stable Diffusion XL (SDXL). The original checkpoints are available at the URLs below:
* [mhdang/dpo-sd1.5-text2image-v1](https://huggingface.co/mhdang/dpo-sd1.5-text2image-v1)
* [mhdang/dpo-sdxl-text2image-v1](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1)
> 💡 Note: The scripts are highly experimental and were only tested on low-data regimes. Proceed with caution. Feel free to let us know about your findings via GitHub issues.
## SD training command
```bash
accelerate launch train_diffusion_dpo.py \
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
--output_dir="diffusion-dpo" \
--mixed_precision="fp16" \
--dataset_name=kashif/pickascore \
--resolution=512 \
--train_batch_size=16 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing \
--use_8bit_adam \
--rank=8 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=10000 \
--checkpointing_steps=2000 \
--run_validation --validation_steps=200 \
--seed="0" \
--report_to="wandb" \
--push_to_hub
```
## SDXL training command
```bash
accelerate launch train_diffusion_dpo_sdxl.py \
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir="diffusion-sdxl-dpo" \
--mixed_precision="fp16" \
--dataset_name=kashif/pickascore \
--train_batch_size=8 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing \
--use_8bit_adam \
--rank=8 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2000 \
--checkpointing_steps=500 \
--run_validation --validation_steps=50 \
--seed="0" \
--report_to="wandb" \
--push_to_hub
```
## Acknowledgements
This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.
\ No newline at end of file
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft
wandb
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment