Unverified commit a1d55e14, authored by Sayak Paul, committed by GitHub

Change the default `weighting_scheme` in the SD3 scripts (#8639)

* change to logit_normal as the weighting scheme

* sensible default note
parent e5564d45
......@@ -11,6 +11,8 @@ The `train_dreambooth_sd3.py` script shows how to implement the training procedu
huggingface-cli login
```
+ This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
......@@ -52,8 +54,6 @@ write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
- Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
......@@ -72,8 +72,6 @@ snapshot_download(
)
```
- This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
Now, we can launch training using:
```bash
......@@ -116,6 +114,8 @@ To better track our training experiments, we're using the following flags in the
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
+ Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
To perform DreamBooth with LoRA, run:
```bash
......@@ -142,3 +142,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
--seed="0" \
--push_to_hub
```
+ ## Other notes
+ We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
\ No newline at end of file
......@@ -477,7 +477,10 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
- "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+ "--weighting_scheme",
+ type=str,
+ default="logit_normal",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
......
......@@ -472,7 +472,10 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
- "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+ "--weighting_scheme",
+ type=str,
+ default="logit_normal",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
......
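
To make the effect of the new default concrete, here is a minimal standalone sketch rather than the training scripts' actual code. It rebuilds only the `--weighting_scheme` and `--logit_mean` arguments shown in the diff, and then draws values in (0, 1) by passing Gaussian samples through a sigmoid, which is what a logit-normal distribution amounts to. The `--logit_std` argument, its default of 1.0, and the helper names `build_parser` and `sample_logit_normal` are assumptions made for illustration; they do not appear in the diff above.

```python
import argparse
import math
import random


def build_parser():
    """Rebuild the two arguments visible in the diff, plus one assumed argument."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="logit_normal",  # new default introduced by this commit
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
    )
    parser.add_argument("--logit_mean", type=float, default=0.0)
    # Assumption: --logit_std is not shown in the diff; 1.0 is a placeholder default.
    parser.add_argument("--logit_std", type=float, default=1.0)
    return parser


def sample_logit_normal(batch_size, mean, std, rng):
    """Draw values in (0, 1) by applying a sigmoid to Gaussian samples."""
    return [1.0 / (1.0 + math.exp(-rng.gauss(mean, std))) for _ in range(batch_size)]


# With no flags passed, the scheme resolves to the new default.
args = build_parser().parse_args([])
u = sample_logit_normal(4, args.logit_mean, args.logit_std, random.Random(0))

# The previous behaviour is still available by passing the flag explicitly.
legacy_args = build_parser().parse_args(["--weighting_scheme", "sigma_sqrt"])
```

With a mean of 0.0 the samples concentrate around 0.5, so under these assumptions intermediate values are drawn more often than values near 0 or 1. Passing `--weighting_scheme` on the command line overrides the default, as the last line shows.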