Unverified Commit 4a06c745 authored by Bagheera, committed by GitHub

Min-SNR Gamma: follow-up fix for zero-terminal SNR models on v-prediction or epsilon (#5177)



* merge with main

* fix flax example

* fix onnx example

---------
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 89d8f848
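Every script touched below gets the same weighting rule: Min-SNR gamma weights, a +1 floor for the velocity objective, and a guard for the zero-SNR terminal timestep. As a standalone illustration, here is a minimal PyTorch sketch of that rule; the function name and the toy SNR values are hypothetical, only the weighting logic mirrors the diff:

```python
import torch

def min_snr_weights(snr: torch.Tensor, snr_gamma: float, prediction_type: str) -> torch.Tensor:
    # min(SNR, gamma) / SNR, the Min-SNR gamma weighting (Section 4.2 of the Min-SNR paper).
    base_weight = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
    if prediction_type == "v_prediction":
        # The velocity objective floors the SNR weight at one.
        weights = base_weight + 1
    else:
        # Epsilon and sample prediction use the base weights.
        weights = base_weight
    # Zero-terminal-SNR schedules have SNR == 0 at the final timestep, so the
    # division above produces 0/0 == NaN there; pin that weight to 1 instead.
    weights[snr == 0] = 1.0
    return weights

# Toy SNR values, including the zero-terminal case:
snr = torch.tensor([0.0, 0.5, 4.0, 25.0])
print(min_snr_weights(snr, snr_gamma=5.0, prediction_type="v_prediction"))
# tensor([1.0000, 2.0000, 2.0000, 1.2000])
```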
@@ -907,10 +907,17 @@ def main():
             if args.snr_gamma is not None:
                 snr = jnp.array(compute_snr(timesteps))
-                snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+                base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
                 if noise_scheduler.config.prediction_type == "v_prediction":
-                    # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    snr_loss_weights = snr_loss_weights + 1
+                    snr_loss_weights = base_weights + 1
+                else:
+                    # Epsilon and sample prediction use the base weights.
+                    snr_loss_weights = base_weights
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division above yields a NaN weight; pin that timestep's weight to 1, or the loss
+                # goes to NaN within a step. JAX arrays are immutable, hence the functional update.
+                snr_loss_weights = snr_loss_weights.at[snr == 0].set(1.0)
                 loss = loss * snr_loss_weights
             loss = loss.mean()
...
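A note on the flax hunk above: JAX arrays are immutable, so the PyTorch-style in-place assignment `weights[snr == 0] = 1.0` used in the other scripts would raise a TypeError here; the override has to be a functional update via `.at[...].set(...)` (or `jnp.where`). A minimal sketch with made-up values:

```python
import jax.numpy as jnp

snr = jnp.array([0.0, 0.5, 4.0, 25.0])
weights = jnp.where(snr < 5.0, snr, jnp.ones_like(snr) * 5.0) / snr  # NaN where snr == 0

# weights[snr == 0] = 1.0 would raise: JAX arrays are immutable.
weights = weights.at[snr == 0].set(1.0)          # functional update
# equivalently: weights = jnp.where(snr == 0, 1.0, weights)
print(weights)                                   # -> [1.  1.  1.  0.2]
```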
@@ -801,9 +801,22 @@ def main():
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+                if noise_scheduler.config.prediction_type == "v_prediction":
+                    # The velocity objective needs the SNR weight floored to a minimum of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample prediction both use the base weights.
+                    mse_loss_weights = base_weight
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division above yields a NaN weight; the MSE loss weight for that timestep is set
+                # unconditionally to 1. Otherwise the loss goes to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
...
@@ -654,9 +654,22 @@ def main():
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+                if noise_scheduler.config.prediction_type == "v_prediction":
+                    # The velocity objective needs the SNR weight floored to a minimum of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample prediction both use the base weights.
+                    mse_loss_weights = base_weight
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division above yields a NaN weight; the MSE loss weight for that timestep is set
+                # unconditionally to 1. Otherwise the loss goes to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
...
@@ -685,9 +685,22 @@ def main():
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+                if noise_scheduler.config.prediction_type == "v_prediction":
+                    # The velocity objective needs the SNR weight floored to a minimum of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample prediction both use the base weights.
+                    mse_loss_weights = base_weight
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division above yields a NaN weight; the MSE loss weight for that timestep is set
+                # unconditionally to 1. Otherwise the loss goes to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
...
@@ -833,9 +833,22 @@ def main():
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+                if noise_scheduler.config.prediction_type == "v_prediction":
+                    # The velocity objective needs the SNR weight floored to a minimum of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample prediction both use the base weights.
+                    mse_loss_weights = base_weight
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division above yields a NaN weight; the MSE loss weight for that timestep is set
+                # unconditionally to 1. Otherwise the loss goes to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
...
@@ -872,12 +872,21 @@ def main():
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
                 if noise_scheduler.config.prediction_type == "v_prediction":
                     # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    mse_loss_weights = mse_loss_weights + 1
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample prediction use the base weights.
+                    mse_loss_weights = base_weight
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division above yields a NaN weight; the MSE loss weight for that timestep is set
+                # unconditionally to 1. Otherwise the loss goes to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
...
@@ -952,12 +952,22 @@ def main():
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
                 if noise_scheduler.config.prediction_type == "v_prediction":
-                    # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    mse_loss_weights = mse_loss_weights + 1
+                    # The velocity objective needs the SNR weight floored to a minimum of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample prediction both use the base weights.
+                    mse_loss_weights = base_weight
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division above yields a NaN weight; the MSE loss weight for that timestep is set
+                # unconditionally to 1. Otherwise the loss goes to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
...
@@ -783,12 +783,22 @@ def main():
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
                 if noise_scheduler.config.prediction_type == "v_prediction":
-                    # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    mse_loss_weights = mse_loss_weights + 1
+                    # The velocity objective needs the SNR weight floored to a minimum of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample prediction both use the base weights.
+                    mse_loss_weights = base_weight
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division above yields a NaN weight; the MSE loss weight for that timestep is set
+                # unconditionally to 1. Otherwise the loss goes to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
...
@@ -1072,12 +1072,22 @@ def main(args):
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
                 if noise_scheduler.config.prediction_type == "v_prediction":
-                    # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    mse_loss_weights = mse_loss_weights + 1
+                    # The velocity objective needs the SNR weight floored to a minimum of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample prediction both use the base weights.
+                    mse_loss_weights = base_weight
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division above yields a NaN weight; the MSE loss weight for that timestep is set
+                # unconditionally to 1. Otherwise the loss goes to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
...
@@ -1100,6 +1100,11 @@ def main(args):
                     # Epsilon and sample both use the same loss weights.
                     mse_loss_weights = base_weight
+                # For zero-terminal SNR, we have to handle the case where the SNR is zero and the
+                # division yields a NaN weight; the MSE loss weight for that timestep is set
+                # unconditionally to 1. Otherwise the loss goes to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
...
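The reason the guard is unconditional: a single NaN weight poisons the whole batch, because the final reduction averages the weighted per-sample losses. A quick sketch of the failure mode, with made-up loss values (the variable names are illustrative, not from the scripts):

```python
import torch

snr = torch.tensor([0.0, 4.0, 25.0])             # zero-terminal SNR sometimes draws snr == 0
per_sample_loss = torch.tensor([0.9, 0.4, 0.1])  # made-up per-sample MSE losses
weights = torch.minimum(snr, torch.full_like(snr, 5.0)) / snr

print((per_sample_loss * weights).mean())        # tensor(nan): one NaN weight poisons the mean
weights[snr == 0] = 1.0
print((per_sample_loss * weights).mean())        # finite again
```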