Unverified Commit 6fac1369 authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

Add features to the Dreambooth LoRA SDXL training script (#5508)



* Additions:
- support for different lr for text encoder
- support for Prodigy optimizer
- support for min snr gamma
- support for custom captions and dataset loading from the hub

* adjusted --caption_column behaviour (to -not- use the second column of the dataset by default if --caption_column is not provided)

* fixed --output_dir / --model_dir_name confusion

* added --repeats, --adam_weight_decay_text_encoder
+ some fixes

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* - import compute_snr from diffusers/training_utils.py
- cluster adamw together
- when using 'prodigy', if --train_text_...
parent 1093f9d6
...@@ -421,6 +421,49 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -421,6 +421,49 @@ class ExamplesTestsAccelerate(unittest.TestCase):
) )
self.assertTrue(starts_with_unet) self.assertTrue(starts_with_unet)
def test_dreambooth_lora_sdxl_custom_captions(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--caption_column text
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--caption_column text
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--train_text_encoder
""".split()
run_command(self._launch_args + test_args)
def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self): def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment