Unverified Commit a1d55e14 authored by Sayak Paul, committed by GitHub

Change the default `weighting_scheme` in the SD3 scripts (#8639)

* change to logit_normal as the weighting scheme

* sensible default note
parent e5564d45
@@ -11,6 +11,8 @@ The `train_dreambooth_sd3.py` script shows how to implement the training procedu
 huggingface-cli login
 ```

+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
 ## Running locally with PyTorch

 ### Installing the dependencies
@@ -52,8 +54,6 @@ write_basic_config()
 ```

 When running `accelerate config`, if we specify torch compile mode to True, there can be dramatic speedups.

-Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
-
 ### Dog toy example
@@ -72,8 +72,6 @@ snapshot_download(
 )
 ```

-This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
-
 Now, we can launch training using:

 ```bash
@@ -116,6 +114,8 @@ To better track our training experiments, we're using the following flags in the
 [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning-like performance but with a fraction of learnable parameters.

+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
 To perform DreamBooth with LoRA, run:

 ```bash
@@ -142,3 +142,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
   --seed="0" \
   --push_to_hub
 ```
+
+## Other notes
+
+We default to the "logit_normal" weighting scheme for the loss, following the SD3 paper. Thanks to @bghira for helping us discover that with the other weighting schemes supported by the training script, training may incur numerical instabilities.
\ No newline at end of file
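For intuition about the new default, here is a minimal sketch of the logit-normal timestep sampling described in the SD3 paper. This is illustrative, not the scripts' actual implementation: `sample_logit_normal_timesteps` is a hypothetical helper, and the `logit_std` parameter is an assumption paired with the `--logit_mean` flag visible in the diff below.

```python
import torch

def sample_logit_normal_timesteps(batch_size, logit_mean=0.0, logit_std=1.0, generator=None):
    """Draw t in (0, 1) from a logit-normal distribution: t = sigmoid(u), u ~ N(mean, std)."""
    # Hypothetical helper; logit_mean mirrors the script's --logit_mean flag,
    # logit_std is an assumed companion parameter.
    u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), generator=generator)
    # The sigmoid squashes the Gaussian into (0, 1); with mean 0 and std 1,
    # the mass concentrates around t = 0.5, i.e. intermediate noise levels.
    return torch.sigmoid(u)

# With the defaults matching --logit_mean 0.0, most sampled timesteps
# land near the middle of the schedule.
print(sample_logit_normal_timesteps(batch_size=4))
```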
@@ -477,7 +477,10 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+        "--weighting_scheme",
+        type=str,
+        default="logit_normal",
+        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
     )
     parser.add_argument(
         "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
...
@@ -472,7 +472,10 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+        "--weighting_scheme",
+        type=str,
+        default="logit_normal",
+        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
     )
     parser.add_argument(
         "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
...
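The two parser changes above are the substance of the commit: the default flips from `sigma_sqrt` to `logit_normal` in both scripts. As a quick sanity check, this isolated reproduction of the updated argument shows what a launch without the flag now resolves to:

```python
import argparse

# Isolated reproduction of the argument definition changed in both scripts.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--weighting_scheme",
    type=str,
    default="logit_normal",
    choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
)

# A launch that omits --weighting_scheme now picks up the new default.
args = parser.parse_args([])
print(args.weighting_scheme)  # -> logit_normal
```

Anyone who relied on the old behavior can still pass `--weighting_scheme sigma_sqrt` explicitly, since the set of choices is unchanged.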