"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2715079344b725bdb045f601551dae02509e393e"
Unverified commit 1a8b3c2e, authored by Sayak Paul and committed by GitHub

[Training] SD3 training fixes (#8917)



* SD3 training fixes
Co-authored-by: bghira <59658056+bghira@users.noreply.github.com>

* rewrite noise addition part to respect the eqn.

* styler

* Update examples/dreambooth/README_sd3.md
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

---------
Co-authored-by: bghira <59658056+bghira@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
parent 56e772ab
@@ -183,4 +183,6 @@ accelerate launch train_dreambooth_lora_sd3.py \
 ## Other notes
-We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
\ No newline at end of file
+1. We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
+2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917).
+3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well.
\ No newline at end of file
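As an aside for note 1 above, here is a minimal, self-contained sketch of what "logit_normal" timestep sampling looks like. The actual scripts rely on the `compute_density_for_timestep_sampling` helper from `diffusers.training_utils`; the function name and defaults below are only illustrative.

```python
import torch

def sample_logit_normal_indices(batch_size, logit_mean=0.0, logit_std=1.0, num_train_timesteps=1000):
    # Logit-normal sampling as in the SD3 paper: u = sigmoid(n) with n ~ N(logit_mean, logit_std),
    # which concentrates samples around the middle of (0, 1).
    u = torch.sigmoid(torch.randn(batch_size) * logit_std + logit_mean)
    # Map the continuous samples onto the scheduler's discrete timestep grid.
    return (u * num_train_timesteps).long().clamp(max=num_train_timesteps - 1)

print(sample_logit_normal_indices(4))  # e.g. tensor([507, 289, 736, 481])
```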
@@ -523,6 +523,13 @@ def parse_args(input_args=None):
         default=1.29,
         help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
+    parser.add_argument(
+        "--precondition_outputs",
+        type=int,
+        default=1,
+        help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
+        "model `target` is calculated.",
+    )
     parser.add_argument(
         "--optimizer",
         type=str,
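Because the new flag is declared with `type=int` rather than as a `store_true` switch, it is turned off by passing an explicit `0`. A small stand-alone sketch of how the flag parses and how the training code gates on it (this is a toy parser, not the script itself):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--precondition_outputs",
    type=int,
    default=1,
    help="1 (default) applies EDM-style preconditioning to the model outputs; 0 disables it.",
)

args = parser.parse_args([])                               # default: preconditioning on
assert args.precondition_outputs == 1

args = parser.parse_args(["--precondition_outputs", "0"])  # explicitly disabled
# The training loop simply checks truthiness: `if args.precondition_outputs:`
assert not args.precondition_outputs
```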
@@ -1636,7 +1643,7 @@ def main(args):
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = model_input * vae.config.scaling_factor
+                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

                 # Sample noise that we'll add to the latents
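The fixed line normalizes the sampled latents with both of the SD3 VAE's statistics, which means decoding has to undo them in reverse order. Below is a minimal round-trip sketch, assuming `vae` is an SD3-style `AutoencoderKL` whose config exposes `shift_factor` and `scaling_factor` as in the diff; the helper names are made up for illustration.

```python
import torch

def encode_to_latents(vae, pixel_values: torch.Tensor) -> torch.Tensor:
    # Sample from the encoder posterior, then shift and scale as in the fixed training code.
    latents = vae.encode(pixel_values).latent_dist.sample()
    return (latents - vae.config.shift_factor) * vae.config.scaling_factor

def decode_from_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    # Invert the normalization (divide by the scale, add the shift back) before decoding.
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    return vae.decode(latents).sample
```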
@@ -1656,8 +1663,9 @@ def main(args):
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                 # Add noise according to flow matching.
+                # zt = (1 - texp) * x + texp * z1
                 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-                noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
+                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

                 # Predict the noise residual
                 model_pred = transformer(
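The reordered line now matches the flow-matching comment `zt = (1 - texp) * x + texp * z1`: at sigma 0 the model sees the clean latent, at sigma 1 pure noise. A tiny self-contained check of that interpolation (tensor shapes are arbitrary):

```python
import torch

model_input = torch.randn(2, 16, 64, 64)       # clean latents x
noise = torch.randn_like(model_input)           # z1 ~ N(0, I)
sigmas = torch.tensor([0.0, 1.0]).view(-1, 1, 1, 1)

# zt = (1 - sigma) * x + sigma * z1, as in the diff above.
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

assert torch.allclose(noisy_model_input[0], model_input[0])  # sigma = 0 -> clean latent
assert torch.allclose(noisy_model_input[1], noise[1])        # sigma = 1 -> pure noise
```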
@@ -1670,14 +1678,18 @@ def main(args):

                 # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
                 # Preconditioning of the model outputs.
-                model_pred = model_pred * (-sigmas) + noisy_model_input
+                if args.precondition_outputs:
+                    model_pred = model_pred * (-sigmas) + noisy_model_input

                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

                 # flow matching loss
-                target = model_input
+                if args.precondition_outputs:
+                    target = model_input
+                else:
+                    target = noise - model_input

                 if args.with_prior_preservation:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
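The two branches are two parameterizations of the same objective: with `--precondition_outputs=1` the raw prediction is folded back into an estimate of the clean latent and regressed against `model_input`, while with `0` it is regressed against the velocity-style target `noise - model_input`. As a sanity-check sketch (not code from the script), the residuals of the two formulations differ only by a factor of `-sigma`, i.e. by `sigma**2` in the squared loss before weighting:

```python
import torch

x = torch.randn(4, 16, 8, 8)          # clean latents (model_input)
noise = torch.randn_like(x)
sigmas = torch.rand(4, 1, 1, 1)
zt = (1.0 - sigmas) * x + sigmas * noise
v = torch.randn_like(x)               # stand-in for the transformer's raw output

# Preconditioned branch: model_pred = v * (-sigma) + zt, target = x.
residual_preconditioned = (v * (-sigmas) + zt) - x
# Raw branch: model_pred = v, target = noise - x.
residual_raw = v - (noise - x)

# residual_preconditioned == -sigma * residual_raw, so the MSE terms differ by sigma**2.
assert torch.allclose(residual_preconditioned, -sigmas * residual_raw, atol=1e-5)
```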
@@ -494,6 +494,13 @@ def parse_args(input_args=None):
         default=1.29,
         help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
+    parser.add_argument(
+        "--precondition_outputs",
+        type=int,
+        default=1,
+        help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
+        "model `target` is calculated.",
+    )
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -1549,7 +1556,7 @@ def main(args):
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = model_input * vae.config.scaling_factor
+                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

                 # Sample noise that we'll add to the latents
@@ -1569,8 +1576,9 @@ def main(args):
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                 # Add noise according to flow matching.
+                # zt = (1 - texp) * x + texp * z1
                 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-                noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
+                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

                 # Predict the noise residual
                 if not args.train_text_encoder:
@@ -1598,13 +1606,18 @@ def main(args):

                 # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
                 # Preconditioning of the model outputs.
-                model_pred = model_pred * (-sigmas) + noisy_model_input
+                if args.precondition_outputs:
+                    model_pred = model_pred * (-sigmas) + noisy_model_input

                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

                 # flow matching loss
-                target = model_input
+                if args.precondition_outputs:
+                    target = model_input
+                else:
+                    target = noise - model_input

                 if args.with_prior_preservation:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
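For the prior-preservation path referenced in the last context line above: the batch is built by concatenating instance and class (prior) examples, so prediction and target are chunked back into halves and each half gets its own loss term. The sketch below shows that general DreamBooth pattern under the assumption of such a concatenated batch; it omits the per-sample `weighting` the script applies, and `prior_loss_weight` mirrors the script's CLI argument of the same name.

```python
import torch
import torch.nn.functional as F

def prior_preservation_loss(model_pred, target, prior_loss_weight=1.0):
    # Batch layout assumed: [instance examples; class (prior) examples] along dim 0.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    # The prior term keeps the class distribution from drifting while the instance term personalizes.
    return instance_loss + prior_loss_weight * prior_loss
```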