"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "dc3e0ca59bf26ebcc9f12ed186bfe8fca86c3a1b"
Unverified commit ee4ab238, authored by Linoy Tsaban, committed by GitHub

[SD3 dreambooth-lora training] small updates + bug fixes (#9682)

* add latent caching + small updates

* update license

* replace manual gc.collect() / torch.cuda.empty_cache() calls with free_memory

* add --upcast_before_saving to allow saving transformer weights in lower precision

* fix the list of models passed to accelerator.accumulate (include the text encoders when they are trained)

* fix mixed precision issue as proposed in https://github.com/huggingface/diffusers/pull/9565

* small update to readme

* style

* fix caching latents

* style

* add tests for latent caching

* style

* fix latent caching

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent cef4f65c
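
The headline change is optional VAE latent caching. As a minimal sketch of the pattern (the helper name cache_vae_latents is hypothetical; in the actual script the loop lives inline in main(), as the diff below shows):

    import torch

    def cache_vae_latents(vae, dataloader, device, dtype):
        # Single no-grad pass over the dataset: keep each batch's latent
        # distribution so it can still be sampled stochastically every epoch.
        cache = []
        with torch.no_grad():
            for batch in dataloader:
                pixel_values = batch["pixel_values"].to(device, dtype=dtype)
                cache.append(vae.encode(pixel_values).latent_dist)
        return cache

In the training loop, latents_cache[step].sample() then replaces the per-step vae.encode(...) call, and the VAE can be freed once caching is done (unless a validation pipeline still needs it).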
@@ -136,7 +136,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
   --resolution=512 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
-  --learning_rate=1e-5 \
+  --learning_rate=4e-4 \
   --report_to="wandb" \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
...
@@ -103,6 +103,39 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
         )
         self.assertTrue(starts_with_expected_prefix)

+    def test_dreambooth_lora_latent_caching(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --cache_latents
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)

     def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
...
@@ -140,7 +140,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
         from_training=True,
-        license="openrail++",
+        license="other",
         base_model=base_model,
         prompt=instance_prompt,
         model_description=model_description,
@@ -186,7 +186,7 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -608,6 +608,12 @@ def parse_args(input_args=None):
             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
         ),
     )
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Cache the VAE latents",
+    )
     parser.add_argument(
         "--report_to",
         type=str,
@@ -628,6 +634,15 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--upcast_before_saving",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "Defaults to precision dtype used for training to save memory"
+        ),
+    )
     parser.add_argument(
         "--prior_generation_precision",
         type=str,
@@ -1394,6 +1409,16 @@ def main(args):
             logger.warning(
                 "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
             )
+        if args.train_text_encoder and args.text_encoder_lr:
+            logger.warning(
+                f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
+                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+                f"When using prodigy only learning_rate is used as the initial learning rate."
+            )
+            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+            # --learning_rate
+            params_to_optimize[1]["lr"] = args.learning_rate
+            params_to_optimize[2]["lr"] = args.learning_rate

        optimizer = optimizer_class(
            params_to_optimize,
@@ -1440,6 +1465,9 @@ def main(args):
         pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds

+    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+    # the redundant encoding.
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
             args.instance_prompt, text_encoders, tokenizers
@@ -1484,6 +1512,21 @@ def main(args):
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
             tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)

+    vae_config_shift_factor = vae.config.shift_factor
+    vae_config_scaling_factor = vae.config.scaling_factor
+    if args.cache_latents:
+        latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(
+                    accelerator.device, non_blocking=True, dtype=weight_dtype
+                )
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+        if args.validation_prompt is None:
+            del vae
+            free_memory()
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1500,7 +1543,6 @@ def main(args):
         power=args.lr_power,
     )

-    # Prepare everything with our `accelerator`.
     # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
         (
@@ -1607,8 +1649,9 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
+            if args.train_text_encoder:
+                models_to_accumulate.extend([text_encoder_one, text_encoder_two])
             with accelerator.accumulate(models_to_accumulate):
-                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]

                 # encode batch prompts when custom prompts are provided for each image -
@@ -1639,8 +1682,13 @@ def main(args):
                 )

                 # Convert images to latent space
-                model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+                if args.cache_latents:
+                    model_input = latents_cache[step].sample()
+                else:
+                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    model_input = vae.encode(pixel_values).latent_dist.sample()
+
+                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

                 # Sample noise that we'll add to the latents
@@ -1773,6 +1821,8 @@ def main(args):
                 text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
                     text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
                 )
+                text_encoder_one.to(weight_dtype)
+                text_encoder_two.to(weight_dtype)
                 pipeline = StableDiffusion3Pipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -1793,15 +1843,18 @@ def main(args):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
+                if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two, text_encoder_three
                     free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        transformer = transformer.to(torch.float32)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)

         if args.train_text_encoder:
...
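
With --upcast_before_saving unset, the LoRA weights are now saved in the training precision rather than being force-upcast to float32; either way the file loads back through the standard LoRA loader. A quick smoke test of the saved weights, as a sketch (the model id, output path, and prompt are placeholders):

    import torch
    from diffusers import StableDiffusion3Pipeline

    pipeline = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")
    # Reads the pytorch_lora_weights.safetensors produced by the training script.
    pipeline.load_lora_weights("path/to/output_dir")
    image = pipeline("a photo of sks dog", num_inference_steps=25).images[0]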
@@ -15,7 +15,6 @@
 import argparse
 import copy
-import gc
 import itertools
 import logging
 import math
@@ -51,7 +50,7 @@ from diffusers import (
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import (
     check_min_version,
     is_wandb_available,
@@ -119,7 +118,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
         from_training=True,
-        license="openrail++",
+        license="other",
         base_model=base_model,
         prompt=instance_prompt,
         model_description=model_description,
@@ -164,7 +163,7 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -190,8 +189,7 @@ def log_validation(
     )

     del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    free_memory()

     return images
@@ -1065,8 +1063,7 @@ def main(args):
             image.save(image_filename)

         del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1386,9 +1383,7 @@ def main(args):
         del tokenizers, text_encoders
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
         del text_encoder_one, text_encoder_two, text_encoder_three
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1708,6 +1703,9 @@ def main(args):
                 text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
                     text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
                 )
+                text_encoder_one.to(weight_dtype)
+                text_encoder_two.to(weight_dtype)
+                text_encoder_three.to(weight_dtype)
                 pipeline = StableDiffusion3Pipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -1730,8 +1728,7 @@ def main(args):
         )
         if not args.train_text_encoder:
             del text_encoder_one, text_encoder_two, text_encoder_three
-            torch.cuda.empty_cache()
-            gc.collect()
+            free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
...
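
For reference, the free_memory() helper imported from diffusers.training_utils consolidates the ad-hoc cleanup blocks removed above. Roughly, it behaves like the sketch below (the real helper may also cover non-CUDA backends such as MPS):

    import gc
    import torch

    def free_memory_sketch():
        # Drop dangling Python references first, then release cached GPU blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()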