Unverified Commit 6fac1369 authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

Add features to the Dreambooth LoRA SDXL training script (#5508)



* Additions:
- support for different lr for text encoder
- support for Prodigy optimizer
- support for min snr gamma
- support for custom captions and dataset loading from the hub

* adjusted --caption_column behaviour (to -not- use the second column of the dataset by default if --caption_column is not provided)

* fixed --output_dir / --model_dir_name confusion

* added --repeats, --adam_weight_decay_text_encoder
+ some fixes

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* - import compute_snr from diffusers/training_utils.py
- cluster adamw together
- when using 'prodigy', if --train_text_encoder == True and --text_encoder_lr != --learning rate, changes the lr of the text encoders optimization params to be --learning_rate (otherwise errors)

* shape fixes when custom captions are used

* formatting and a little cleanup

* code styling

* --repeats default value fixed, changed to 1

* bug fix - removed redundant lines of embedding concatenation when using prior_preservation (that duplicated class_prompt embeddings)

* changed dataset loading logic according to the following usecases (to avoid unnecessary dependency on datasets)-
1. user provides --dataset_name
2. user provides local dir --instance_data_dir that contains a metadata .jsonl file
3. user provides local dir --instance_data_dir that contains only images
in cases [1,2] we import datasets and use load_dataset method, in case [3] we process the data same as in the original script setting

* styling fix

* arg name fix

* adjusted the --repeats logic

* -removed redundant arg and 'if' when loading local folder with prompts
-updated readme template
-some default val fixes
-custom caption tests

* image path fix for readme

* code style

* bug fix

* --caption_column arg

* readme fix

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarLinoy Tsaban <linoy@huggingface.co>
parent 1093f9d6
...@@ -421,6 +421,49 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -421,6 +421,49 @@ class ExamplesTestsAccelerate(unittest.TestCase):
) )
self.assertTrue(starts_with_unet) self.assertTrue(starts_with_unet)
def test_dreambooth_lora_sdxl_custom_captions(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--caption_column text
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--caption_column text
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--train_text_encoder
""".split()
run_command(self._launch_args + test_args)
def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self): def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment