Unverified Commit 0fc2fb71 authored by Will Berman, committed by GitHub

dreambooth upscaling fix added latents (#3659)

parent 523a50a8
@@ -540,10 +540,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I
 For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
 LoRA finetuning stage II.

-For finegrained detail like faces, we find that lower learning rates work best.
+For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best.

 For stage II, we find that lower learning rates are also needed.

+We found experimentally that the DDPM scheduler with the default larger number of denoising steps sometimes works better than the DPM Solver scheduler
+used in the training scripts.

 ### Stage II additional validation images

 The stage II validation requires images to upscale; we can download a downsized version of the training set:

@@ -631,7 +634,8 @@ with a T5 loaded from the original model.
 `use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam.

-`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade.
+`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is
+likely the learning rate can be increased with larger batch sizes.

 Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.

@@ -656,7 +660,7 @@ accelerate launch train_dreambooth.py \
   --text_encoder_use_attention_mask \
   --tokenizer_max_length 77 \
   --pre_compute_text_embeddings \
-  --use_8bit_adam \
+  # --use_8bit_adam \
   --set_grads_to_none \
   --skip_save_text_encoder \
   --push_to_hub

@@ -664,10 +668,14 @@ accelerate launch train_dreambooth.py \
 ### IF Stage II Full Dreambooth

-`--learning_rate=1e-8`: Even lower learning rate.
+`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as
+1e-8.

 `--resolution=256`: The upscaler expects higher resolution inputs.

+`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II, particularly with
+faces, required large effective batch sizes.

 ```sh
 export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
 export INSTANCE_DIR="dog"

@@ -682,8 +690,8 @@ accelerate launch train_dreambooth.py \
   --instance_prompt="a sks dog" \
   --resolution=256 \
   --train_batch_size=2 \
-  --gradient_accumulation_steps=2 \
+  --gradient_accumulation_steps=6 \
-  --learning_rate=1e-8 \
+  --learning_rate=5e-6 \
   --max_train_steps=2000 \
   --validation_prompt="a sks dog" \
   --validation_steps=150 \
...
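The documentation change above notes that the DDPM scheduler, with its default larger number of denoising steps, can sometimes give better results than the DPM Solver scheduler used in the training scripts. A minimal sketch of swapping schedulers on a loaded pipeline with the standard diffusers API (the checkpoint path is a hypothetical placeholder):

```py
from diffusers import DiffusionPipeline, DDPMScheduler

# Hypothetical path to a trained dreambooth output directory.
pipe = DiffusionPipeline.from_pretrained("path/to/dreambooth-output")

# Replace the pipeline's scheduler with DDPM, reusing the existing config;
# DDPM then runs with its default, larger number of denoising steps.
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
```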
@@ -574,10 +574,13 @@ upscaler to remove the new token from the instance prompt. I.e. if your stage I
 For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
 LoRA finetuning stage II.

-For finegrained detail like faces, we find that lower learning rates work best.
+For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best.

 For stage II, we find that lower learning rates are also needed.

+We found experimentally that the DDPM scheduler with the default larger number of denoising steps sometimes works better than the DPM Solver scheduler
+used in the training scripts.

 ### Stage II additional validation images

 The stage II validation requires images to upscale; we can download a downsized version of the training set:

@@ -665,7 +668,8 @@ with a T5 loaded from the original model.
 `use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam.

-`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade.
+`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is
+likely the learning rate can be increased with larger batch sizes.

 Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.

@@ -690,7 +694,7 @@ accelerate launch train_dreambooth.py \
   --text_encoder_use_attention_mask \
   --tokenizer_max_length 77 \
   --pre_compute_text_embeddings \
-  --use_8bit_adam \
+  # --use_8bit_adam \
   --set_grads_to_none \
   --skip_save_text_encoder \
   --push_to_hub

@@ -698,10 +702,14 @@ accelerate launch train_dreambooth.py \
 ### IF Stage II Full Dreambooth

-`--learning_rate=1e-8`: Even lower learning rate.
+`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as
+1e-8.

 `--resolution=256`: The upscaler expects higher resolution inputs.

+`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II, particularly with
+faces, required large effective batch sizes.

 ```sh
 export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
 export INSTANCE_DIR="dog"

@@ -716,8 +724,8 @@ accelerate launch train_dreambooth.py \
   --instance_prompt="a sks dog" \
   --resolution=256 \
   --train_batch_size=2 \
-  --gradient_accumulation_steps=2 \
+  --gradient_accumulation_steps=6 \
-  --learning_rate=1e-8 \
+  --learning_rate=5e-6 \
   --max_train_steps=2000 \
   --validation_prompt="a sks dog" \
   --validation_steps=150 \
...
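As context for the batch-size guidance above: the effective batch size is the per-device batch size multiplied by the gradient accumulation steps (and the number of processes). A worked example with the values from the stage II command, assuming a single GPU:

```py
# Assumption: a single accelerate process (one GPU).
train_batch_size = 2
gradient_accumulation_steps = 6
num_processes = 1

effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch_size)  # 12
```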
@@ -52,7 +52,6 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.torch_utils import randn_tensor

 if is_wandb_available():

@@ -1212,14 +1211,8 @@ def main(args):
     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
 )

-if unet.config.in_channels > channels:
-    needed_additional_channels = unet.config.in_channels - channels
-    additional_latents = randn_tensor(
-        (bsz, needed_additional_channels, height, width),
-        device=noisy_model_input.device,
-        dtype=noisy_model_input.dtype,
-    )
-    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+if unet.config.in_channels == channels * 2:
+    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

 if args.class_labels_conditioning == "timesteps":
     class_labels = timesteps
...
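To illustrate the shape logic of the change above: when the stage II UNet expects twice as many input channels as the image has, the model input is built by concatenating the noisy latents with themselves along the channel dimension. A standalone sketch with assumed example values:

```py
import torch

# Assumed example values; the IF stage II upscaler expects 2 * 3 = 6 input channels.
bsz, channels, height, width = 2, 3, 256, 256
unet_in_channels = 6

noisy_model_input = torch.randn(bsz, channels, height, width)
if unet_in_channels == channels * 2:
    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

print(noisy_model_input.shape)  # torch.Size([2, 6, 256, 256])
```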
@@ -60,7 +60,6 @@ from diffusers.models.attention_processor import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.torch_utils import randn_tensor

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.

@@ -1157,14 +1156,8 @@ def main(args):
     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
 )

-if unet.config.in_channels > channels:
-    needed_additional_channels = unet.config.in_channels - channels
-    additional_latents = randn_tensor(
-        (bsz, needed_additional_channels, height, width),
-        device=noisy_model_input.device,
-        dtype=noisy_model_input.dtype,
-    )
-    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+if unet.config.in_channels == channels * 2:
+    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

 if args.class_labels_conditioning == "timesteps":
     class_labels = timesteps
...
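The LoRA training script above receives the same fix. To verify the upscaler's channel configuration directly, one could load just the UNet (this assumes access to the gated DeepFloyd weights; the expected value of 6 is an assumption based on the 2 * 3 image channels used in the sketch above):

```py
from diffusers import UNet2DConditionModel

# Assumes the gated DeepFloyd/IF-II-L-v1.0 weights are accessible (Hub login or local cache).
unet = UNet2DConditionModel.from_pretrained("DeepFloyd/IF-II-L-v1.0", subfolder="unet")
print(unet.config.in_channels)  # expected: 6, i.e. 2 * 3 image channels
```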