Unverified Commit 572d8e20 authored by Andrés Mauricio Repetto Ferrero's avatar Andrés Mauricio Repetto Ferrero Committed by GitHub
Browse files

Adding better way to define multiple concepts and also validation capabilities. (#3807)



* - Added validation parameters
- Changed some parameter descriptions to better explain their use.
- Fixed a few typos.
- Added concept_list parameter for better management of multiple subjects
- changed logic for image validation

* - Fixed bad logic for class data root directories

* Defaulting validation_steps to None for an easier logic

* Fixed multiple validation prompts

* Fixed bug on validation negative prompt

* Changed validation logic for tracker.

* Added uuid for validation image labeling

* Fix error when comparing validation prompts and validation negative prompts

* Improved error message when negative prompts for validation are more than the number of prompts

* - Changed image tracking number from epoch to global_step
- Added Typing for functions

* Added some validations more when using concept_list parameter and the regular ones.

* Fixed error message

* Added more validations for validation parameters

* Improved messaging for errors

* Fixed validation error for parameters with default values

* - Added train step to image name for validation
- reformatted code

* - Added train step to image's name for validation
- reformatted code

* Updated README.md file.

* reverted back original script of train_dreambooth.py

* reverted back original script of train_dreambooth.py

* left one blank line at the eof

* reverted back setup.py

* reverted back setup.py

* added same logic for when parameters for prior preservation are used without enabling the flag while using concept_list parameter.

* Ran black formatter.

* fixed a few strings

* fixed import sort with isort and removed fstrings without placeholder

* fixed import order with ruff (since with isort wasn't ok)

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 2e8668f0
......@@ -86,6 +86,53 @@ This example shows training for 2 subjects, but please note that the model can b
Note also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed to be good when more intuitive words are used.
**Important**: New parameters are added to the script, making possible to validate the progress of the training by
generating images at specified steps. Taking also into account that a comma separated list in a text field for a prompt
it's never a good idea (simply because it is very common in prompts to have them as part of a regular text) we
introduce the `concept_list` parameter: allowing to specify a json-like file where you can define the different
configuration for each subject that you want to train.
An example of how to generate the file:
```python
import json
# here we are using parameters for prior-preservation and validation as well.
concepts_list = [
{
"instance_prompt": "drawing of a t@y meme",
"class_prompt": "drawing of a meme",
"instance_data_dir": "/some_folder/meme_toy",
"class_data_dir": "/data/meme",
"validation_prompt": "drawing of a t@y meme about football in Uruguay",
"validation_negative_prompt": "black and white"
},
{
"instance_prompt": "drawing of a sks sir",
"class_prompt": "drawing of a sir",
"instance_data_dir": "/some_other_folder/sir_sks",
"class_data_dir": "/data/sir",
"validation_prompt": "drawing of a sks sir with the Uruguayan sun in his chest",
"validation_negative_prompt": "an old man",
"validation_guidance_scale": 20,
"validation_number_images": 3,
"validation_inference_steps": 10
}
]
with open("concepts_list.json", "w") as f:
json.dump(concepts_list, f, indent=4)
```
And then just point to the file when executing the script:
```bash
# exports...
accelerate launch train_multi_subject_dreambooth.py \
# more parameters...
--concepts_list="concepts_list.json"
```
You can use the helper from the script to get a better sense of each parameter.
### Inference
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment