Unverified commit 74e43a4f, authored by Bagheera and committed by GitHub
Browse files

Resolve v_prediction issue for min-SNR gamma weighted loss function (#5096)



* Resolve v_prediction issue for min-SNR gamma weighted loss function

* Combine MSE loss calculation of epsilon and velocity, with a note about the application of the epsilon code to sample prediction

* style

---------
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 81331f3b
......@@ -332,15 +332,6 @@ def parse_args(input_args=None):
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--force_snr_gamma",
action="store_true",
help=(
"When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN"
" condition when computing the SNR with a sigma value of zero. This parameter overrides the check,"
" allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results."
),
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument(
"--allow_tf32",
......@@ -554,18 +545,6 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# Check for terminal SNR in combination with SNR Gamma
if (
args.snr_gamma
and not args.force_snr_gamma
and (
hasattr(noise_scheduler.config, "rescale_betas_zero_snr") and noise_scheduler.config.rescale_betas_zero_snr
)
):
raise ValueError(
f"The selected noise scheduler for the model {args.pretrained_model_name_or_path} uses rescaled betas for zero SNR.\n"
"When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n"
"This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero."
)
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
......@@ -1013,9 +992,17 @@ def main(args):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
mse_loss_weights = (
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective: add one to the base weight so that the loss weight is floored at one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment