"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "029a9da00b4015b6988a28e353fc389572ac6254"
Unverified Commit 83062fb8 authored by Linoy Tsaban, committed by GitHub

[Advanced DreamBooth LoRA SDXL] Support EDM-style training (follow up of #7126) (#7182)



* add edm style training

* style

* finish adding edm training feature

* import fix

* fix latents mean

* minor adjustments

* add edm to readme

* style

* fix autocast and scheduler config issues when using edm

* style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent b6d7e31d
@@ -259,6 +259,50 @@ pip install git+https://github.com/huggingface/peft.git
**Inference**
The inference is the same as if you train a regular LoRA 🤗
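For reference, a minimal inference sketch; the LoRA repo id below is a placeholder for your own run, and with pivotal tuning (`--train_text_encoder_ti`) you would additionally load the learned embeddings:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# Load the LoRA produced by the training script (local output dir or Hub repo id).
pipe.load_lora_weights("your-username/3d-icon-SDXL-LoRA")
image = pipe("a 3d icon of an astronaut riding a horse").images[0]
```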
## Conducting EDM-style training
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364). To enable it, simply set:
```diff
+ --do_edm_style_training \
```
Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
```bash
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
--dataset_name="linoyts/3d_icon" \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir="3d-icon-SDXL-LoRA" \
--do_edm_style_training \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=3 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--train_text_encoder_ti_frac=0.5\
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
> [!CAUTION]
> Min-SNR gamma is not supported with EDM-style training yet. When training with the PlaygroundAI model, it's recommended not to pass a `--variant` argument.
### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
...
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
import argparse
+import contextlib
import gc
import hashlib
import itertools
+import json
import logging
import math
import os
@@ -37,7 +39,7 @@ import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
from packaging import version
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
@@ -55,6 +57,8 @@ from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
+    EDMEulerScheduler,
+    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
@@ -79,6 +83,20 @@ check_min_version("0.27.0.dev0")
logger = get_logger(__name__)

+def determine_scheduler_type(pretrained_model_name_or_path, revision):
+    model_index_filename = "model_index.json"
+    if os.path.isdir(pretrained_model_name_or_path):
+        model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
+    else:
+        model_index = hf_hub_download(
+            repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
+        )
+
+    with open(model_index, "r") as f:
+        scheduler_type = json.load(f)["scheduler"][1]
+    return scheduler_type
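To illustrate what the new `determine_scheduler_type` helper returns (a sketch; the class names come from each checkpoint's `model_index.json`, and the Hub repo ids are the ones from this PR's README):

```python
# Vanilla SDXL stores "scheduler": ["diffusers", "EulerDiscreteScheduler"],
# so nothing EDM-related is triggered:
determine_scheduler_type("stabilityai/stable-diffusion-xl-base-1.0", revision=None)
# An EDM-formulated checkpoint such as playground-v2.5 stores an "EDM*" scheduler
# class name instead, which the training code detects via `"EDM" in scheduler_type`.
determine_scheduler_type("playgroundai/playground-v2.5-1024px-aesthetic", revision=None)
```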
def save_model_card(
    repo_id: str,
    use_dora: bool,
@@ -370,6 +388,11 @@ def parse_args(input_args=None):
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
+    parser.add_argument(
+        "--do_edm_style_training",
+        action="store_true",
+        help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
@@ -1117,6 +1140,8 @@ def main(args):
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )
+    if args.do_edm_style_training and args.snr_gamma is not None:
+        raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")

    logging_dir = Path(args.output_dir, args.logging_dir)
@@ -1234,7 +1259,19 @@
    )

    # Load scheduler and models
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
+    if "EDM" in scheduler_type:
+        args.do_edm_style_training = True
+        noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+        logger.info("Performing EDM-style training!")
+    elif args.do_edm_style_training:
+        noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="scheduler"
+        )
+        logger.info("Performing EDM-style training!")
+    else:
+        noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
    text_encoder_one = text_encoder_cls_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
@@ -1252,7 +1289,12 @@
        revision=args.revision,
        variant=args.variant,
    )
-    vae_scaling_factor = vae.config.scaling_factor
+    latents_mean = latents_std = None
+    if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+        latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+    if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+        latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
    )
@@ -1790,6 +1832,19 @@
        disable=not accelerator.is_local_main_process,
    )

+    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+        # TODO: revisit other sampling algorithms
+        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+        timesteps = timesteps.to(accelerator.device)
+
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+        return sigma
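`get_sigmas` maps each sampled timestep to its position in the scheduler's schedule, gathers the matching sigma, and right-pads the shape so it broadcasts over the latents. A self-contained sketch of just the padding step:

```python
import torch

sigma = torch.tensor([14.6, 0.03])  # one noise level per batch element
while len(sigma.shape) < 4:         # pad up to the (B, C, H, W) rank of the latents
    sigma = sigma.unsqueeze(-1)
print(sigma.shape)                  # torch.Size([2, 1, 1, 1])
```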
    if args.train_text_encoder:
        num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
    elif args.train_text_encoder_ti:  # args.train_text_encoder_ti
@@ -1841,9 +1896,15 @@
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                model_input = vae.encode(pixel_values).latent_dist.sample()

-                model_input = model_input * vae_scaling_factor
-                if args.pretrained_vae_model_name_or_path is None:
-                    model_input = model_input.to(weight_dtype)
+                if latents_mean is None and latents_std is None:
+                    model_input = model_input * vae.config.scaling_factor
+                    if args.pretrained_vae_model_name_or_path is None:
+                        model_input = model_input.to(weight_dtype)
+                else:
+                    latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
+                    latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
+                    model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+                    model_input = model_input.to(dtype=weight_dtype)
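The `else` branch standardizes the latents channel-wise before the usual scaling, which Playground-style VAEs expect. A toy numeric sketch of the formula; the statistics and scaling factor here are invented, the real values come from `vae.config`:

```python
import torch

z = torch.randn(1, 4, 8, 8)                   # raw VAE latents
latents_mean = torch.full((1, 4, 1, 1), 0.5)  # hypothetical per-channel mean
latents_std = torch.full((1, 4, 1, 1), 2.0)   # hypothetical per-channel std
scaling_factor = 0.13                         # hypothetical vae.config.scaling_factor
model_input = (z - latents_mean) * scaling_factor / latents_std
```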
                # Sample noise that we'll add to the latents
                noise = torch.randn_like(model_input)
@@ -1854,15 +1915,32 @@
                )
                bsz = model_input.shape[0]

                # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
-                )
-                timesteps = timesteps.long()
+                if not args.do_edm_style_training:
+                    timesteps = torch.randint(
+                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                    )
+                    timesteps = timesteps.long()
+                else:
+                    # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
+                    # instead of discrete timesteps, so here we sample indices to get the noise levels
+                    # from `scheduler.timesteps`
+                    indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
+                    timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

+                # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
+                # We then precondition the final model inputs based on these sigmas instead of the timesteps.
+                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                if args.do_edm_style_training:
+                    sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
+                    if "EDM" in scheduler_type:
+                        inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
+                    else:
+                        inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
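The fallback branch applies the EDM input preconditioning c_in(sigma) = 1 / sqrt(sigma^2 + 1), i.e. the Karras et al. formula with sigma_data = 1, so the UNet sees inputs of roughly unit variance regardless of the noise level. A one-line numeric check:

```python
import torch

sigma = torch.tensor(14.6)           # a high noise level
c_in = 1.0 / (sigma**2 + 1) ** 0.5   # ~0.068: very noisy inputs are scaled down hard
```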
                # time ids
                add_time_ids = torch.cat(
@@ -1888,7 +1966,7 @@
                        }
                        prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                        model_pred = unet(
-                            noisy_model_input,
+                            inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                            timesteps,
                            prompt_embeds_input,
                            added_cond_kwargs=unet_added_conditions,
@@ -1906,14 +1984,42 @@
                    )
                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                    model_pred = unet(
-                        noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
+                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+                        timesteps,
+                        prompt_embeds_input,
+                        added_cond_kwargs=unet_added_conditions,
                    ).sample
+                weighting = None
+                if args.do_edm_style_training:
+                    # Similar to the input preconditioning, the model predictions are also preconditioned
+                    # on noised model inputs (before preconditioning) and the sigmas.
+                    # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                    if "EDM" in scheduler_type:
+                        model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
+                    else:
+                        if noise_scheduler.config.prediction_type == "epsilon":
+                            model_pred = model_pred * (-sigmas) + noisy_model_input
+                        elif noise_scheduler.config.prediction_type == "v_prediction":
+                            model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
+                                noisy_model_input / (sigmas**2 + 1)
+                            )
+                    # We are not doing weighting here because it tends to result in numerical problems.
+                    # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+                    # There might be other alternatives for weighting as well:
+                    # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
+                    if "EDM" not in scheduler_type:
+                        weighting = (sigmas**-2.0).float()
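Both branches convert the raw network output into an estimate of the clean latents, which is why the loss target below becomes `model_input` under EDM-style training. A small sanity check of the epsilon branch (pure algebra, no scheduler required): a perfect noise prediction recovers the original latents exactly.

```python
import torch

x0 = torch.randn(2, 4, 8, 8)        # "clean" latents
eps = torch.randn_like(x0)          # the noise that was added
sigma = torch.tensor(3.0)
noisy = x0 + sigma * eps            # forward process used by the Euler scheduler
denoised = eps * (-sigma) + noisy   # the epsilon branch above, with a perfect prediction
assert torch.allclose(denoised, x0, atol=1e-5)
```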
                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
-                    target = noise
+                    target = model_input if args.do_edm_style_training else noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    target = (
+                        model_input
+                        if args.do_edm_style_training
+                        else noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    )
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@@ -1923,10 +2029,28 @@
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute prior loss
-                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+                    if weighting is not None:
+                        prior_loss = torch.mean(
+                            (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+                                target_prior.shape[0], -1
+                            ),
+                            1,
+                        )
+                        prior_loss = prior_loss.mean()
+                    else:
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                if args.snr_gamma is None:
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                    if weighting is not None:
+                        loss = torch.mean(
+                            (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
+                                target.shape[0], -1
+                            ),
+                            1,
+                        )
+                        loss = loss.mean()
+                    else:
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
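As an aside on the `sigmas**-2.0` weighting used above for the non-EDM-scheduler path: for epsilon prediction it makes the x0-space objective algebraically identical to a plain MSE in noise space, since the preconditioned prediction differs from the clean latents by sigma times the noise error. A quick check of that identity:

```python
import torch

x0, eps, eps_hat = torch.randn(8), torch.randn(8), torch.randn(8)
sigma = 3.0
noisy = x0 + sigma * eps
x0_hat = eps_hat * (-sigma) + noisy        # epsilon-branch output preconditioning
weighted = sigma**-2 * (x0_hat - x0) ** 2  # weighted x0-space error
plain = (eps - eps_hat) ** 2               # plain noise-space error
assert torch.allclose(weighted, plain, atol=1e-5)
```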
@@ -2049,17 +2173,18 @@
                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
                scheduler_args = {}

-                if "variance_type" in pipeline.scheduler.config:
-                    variance_type = pipeline.scheduler.config.variance_type
-
-                    if variance_type in ["learned", "learned_range"]:
-                        variance_type = "fixed_small"
-
-                    scheduler_args["variance_type"] = variance_type
-
-                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    pipeline.scheduler.config, **scheduler_args
-                )
+                if not args.do_edm_style_training:
+                    if "variance_type" in pipeline.scheduler.config:
+                        variance_type = pipeline.scheduler.config.variance_type
+
+                        if variance_type in ["learned", "learned_range"]:
+                            variance_type = "fixed_small"
+
+                        scheduler_args["variance_type"] = variance_type
+
+                    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+                        pipeline.scheduler.config, **scheduler_args
+                    )

                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)
@@ -2067,8 +2192,13 @@
                # run inference
                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                pipeline_args = {"prompt": args.validation_prompt}
+                inference_ctx = (
+                    contextlib.nullcontext()
+                    if "playground" in args.pretrained_model_name_or_path
+                    else torch.cuda.amp.autocast()
+                )

-                with torch.cuda.amp.autocast():
+                with inference_ctx:
                    images = [
                        pipeline(**pipeline_args, generator=generator).images[0]
                        for _ in range(args.num_validation_images)
@@ -2144,15 +2274,18 @@
        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
        scheduler_args = {}

-        if "variance_type" in pipeline.scheduler.config:
-            variance_type = pipeline.scheduler.config.variance_type
-
-            if variance_type in ["learned", "learned_range"]:
-                variance_type = "fixed_small"
-
-            scheduler_args["variance_type"] = variance_type
+        if not args.do_edm_style_training:
+            if "variance_type" in pipeline.scheduler.config:
+                variance_type = pipeline.scheduler.config.variance_type
+
+                if variance_type in ["learned", "learned_range"]:
+                    variance_type = "fixed_small"
+
+                scheduler_args["variance_type"] = variance_type

-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+                pipeline.scheduler.config, **scheduler_args
+            )

        # load attention processors
        pipeline.load_lora_weights(args.output_dir)
...