Unverified commit 02a8d664, authored by Bagheera, committed by GitHub

Min-SNR Gamma: correct the fix for SNR weighted loss in v-prediction by adding 1 to SNR rather than the resulting loss weights (#5238)

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent e6faf607
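
For context: with args.snr_gamma set, each timestep's MSE loss is scaled by min(SNR(t), gamma) / SNR(t), following Section 4.2 of the Min-SNR weighting paper cited in the scripts. For v-prediction the objective effectively behaves as SNR(t) + 1, so the +1 must be applied to the SNR before the clamp and the division. The previous fix instead added 1 to the finished weights (min(SNR, gamma) / SNR + 1), which distorts the weighting and still needed a separate guard against 0/0 at zero SNR. This commit applies the corrected ordering uniformly across the example training scripts (one Flax trainer and several PyTorch trainers, below). A minimal PyTorch sketch of the difference; the variable names mirror the scripts, and the sample values are purely illustrative:

import torch

snr = torch.tensor([0.0, 3.0, 24.0])  # per-timestep SNR; 0.0 occurs at a zero-terminal-SNR timestep
gamma = 5.0

# Old (incorrect): clamp and divide first, then add 1 to the finished weights.
# At snr == 0 this is 0/0 = NaN, which is what the now-deleted guard papered over.
old_weights = torch.minimum(snr, torch.full_like(snr, gamma)) / snr + 1

# New (correct): shift the SNR by 1 first, then clamp and divide.
# The denominator is now at least 1, so the zero-SNR special case disappears.
snr_v = snr + 1
new_weights = torch.minimum(snr_v, torch.full_like(snr_v, gamma)) / snr_v

print(old_weights)  # tensor([   nan, 2.0000, 1.2083])
print(new_weights)  # tensor([1.0000, 1.0000, 0.2000])
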
@@ -907,17 +907,10 @@ def main():
     if args.snr_gamma is not None:
         snr = jnp.array(compute_snr(timesteps))
-        base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
         if noise_scheduler.config.prediction_type == "v_prediction":
-            snr_loss_weights = base_weights + 1
-        else:
-            # Epsilon and sample prediction use the base weights.
-            snr_loss_weights = base_weights
-        # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-        # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-        # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-        snr_loss_weights[snr == 0] = 1.0
+            # Velocity objective requires that we add one to SNR values before we divide by them.
+            snr = snr + 1
+        snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
         loss = loss * snr_loss_weights
     loss = loss.mean()

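A side note on the Flax hunk above (an observation, not part of the commit): the deleted guard snr_loss_weights[snr == 0] = 1.0 is an in-place masked assignment, which JAX arrays do not support, so that line would likely have raised a TypeError had it ever run; the functional spelling would be .at[...].set(...). The corrected code sidesteps the issue for v-prediction, since the denominator snr + 1 is at least 1. A small sketch, with assumed values:

import jax.numpy as jnp

snr = jnp.array([0.0, 3.0, 24.0])            # 0.0 occurs with zero-terminal-SNR schedules
weights = jnp.minimum(snr, 5.0) / snr        # NaN at snr == 0 (0/0)
# weights[snr == 0] = 1.0                    # TypeError: JAX arrays are immutable
weights = jnp.where(snr == 0, 1.0, weights)  # out-of-place equivalent, were it still needed
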
@@ -781,25 +781,13 @@ def main():
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    base_weight = (
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective requires that we add one to SNR values before we divide by them.
+        snr = snr + 1
+    mse_loss_weights = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective needs to be floored to an SNR weight of one.
-        mse_loss_weights = base_weight + 1
-    else:
-        # Epsilon and sample both use the same loss weights.
-        mse_loss_weights = base_weight
-    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-    mse_loss_weights[snr == 0] = 1.0
-    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-    # rebalance the sample-wise losses with their respective loss weights.
-    # Finally, we take the mean of the rebalanced loss.
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
     loss = loss.mean()

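The remaining PyTorch hunks repeat this same change in the other example scripts. They all share one reduction pattern: compute the unreduced MSE, average over every non-batch dimension to get one value per sample, scale by that sample's Min-SNR weight, then average over the batch. A self-contained sketch of that pattern (shapes are illustrative):

import torch
import torch.nn.functional as F

model_pred = torch.randn(4, 4, 32, 32)  # e.g. a batch of 4 latent predictions
target = torch.randn(4, 4, 32, 32)
mse_loss_weights = torch.rand(4)        # one Min-SNR weight per sample

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")   # (4, 4, 32, 32)
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights  # (4,) per-sample
loss = loss.mean()                                                        # weighted batch mean
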
@@ -631,25 +631,13 @@ def main():
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    base_weight = (
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective requires that we add one to SNR values before we divide by them.
+        snr = snr + 1
+    mse_loss_weights = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective needs to be floored to an SNR weight of one.
-        mse_loss_weights = base_weight + 1
-    else:
-        # Epsilon and sample both use the same loss weights.
-        mse_loss_weights = base_weight
-    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-    mse_loss_weights[snr == 0] = 1.0
-    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-    # rebalance the sample-wise losses with their respective loss weights.
-    # Finally, we take the mean of the rebalanced loss.
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
     loss = loss.mean()

@@ -664,25 +664,13 @@ def main():
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    base_weight = (
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective requires that we add one to SNR values before we divide by them.
+        snr = snr + 1
+    mse_loss_weights = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective needs to be floored to an SNR weight of one.
-        mse_loss_weights = base_weight + 1
-    else:
-        # Epsilon and sample both use the same loss weights.
-        mse_loss_weights = base_weight
-    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-    mse_loss_weights[snr == 0] = 1.0
-    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-    # rebalance the sample-wise losses with their respective loss weights.
-    # Finally, we take the mean of the rebalanced loss.
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
     loss = loss.mean()

@@ -811,25 +811,13 @@ def main():
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    base_weight = (
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective requires that we add one to SNR values before we divide by them.
+        snr = snr + 1
+    mse_loss_weights = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective needs to be floored to an SNR weight of one.
-        mse_loss_weights = base_weight + 1
-    else:
-        # Epsilon and sample both use the same loss weights.
-        mse_loss_weights = base_weight
-    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-    mse_loss_weights[snr == 0] = 1.0
-    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-    # rebalance the sample-wise losses with their respective loss weights.
-    # Finally, we take the mean of the rebalanced loss.
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
     loss = loss.mean()

@@ -848,24 +848,13 @@ def main():
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    base_weight = (
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective requires that we add one to SNR values before we divide by them.
+        snr = snr + 1
+    mse_loss_weights = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-        mse_loss_weights = base_weight + 1
-    else:
-        # Epsilon and sample prediction use the base weights.
-        mse_loss_weights = base_weight
-    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-    mse_loss_weights[snr == 0] = 1.0
-    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-    # rebalance the sample-wise losses with their respective loss weights.
-    # Finally, we take the mean of the rebalanced loss.
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
     loss = loss.mean()

@@ -929,25 +929,13 @@ def main():
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
    snr = compute_snr(noise_scheduler, timesteps)
-    base_weight = (
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective requires that we add one to SNR values before we divide by them.
+        snr = snr + 1
+    mse_loss_weights = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective needs to be floored to an SNR weight of one.
-        mse_loss_weights = base_weight + 1
-    else:
-        # Epsilon and sample both use the same loss weights.
-        mse_loss_weights = base_weight
-    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-    mse_loss_weights[snr == 0] = 1.0
-    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-    # rebalance the sample-wise losses with their respective loss weights.
-    # Finally, we take the mean of the rebalanced loss.
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
     loss = loss.mean()

@@ -759,25 +759,13 @@ def main():
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    base_weight = (
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective requires that we add one to SNR values before we divide by them.
+        snr = snr + 1
+    mse_loss_weights = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective needs to be floored to an SNR weight of one.
-        mse_loss_weights = base_weight + 1
-    else:
-        # Epsilon and sample both use the same loss weights.
-        mse_loss_weights = base_weight
-    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-    mse_loss_weights[snr == 0] = 1.0
-    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-    # rebalance the sample-wise losses with their respective loss weights.
-    # Finally, we take the mean of the rebalanced loss.
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
     loss = loss.mean()

@@ -1050,25 +1050,13 @@ def main(args):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    base_weight = (
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective requires that we add one to SNR values before we divide by them.
+        snr = snr + 1
+    mse_loss_weights = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective needs to be floored to an SNR weight of one.
-        mse_loss_weights = base_weight + 1
-    else:
-        # Epsilon and sample both use the same loss weights.
-        mse_loss_weights = base_weight
-    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-    mse_loss_weights[snr == 0] = 1.0
-    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-    # rebalance the sample-wise losses with their respective loss weights.
-    # Finally, we take the mean of the rebalanced loss.
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
     loss = loss.mean()

@@ -1067,25 +1067,13 @@ def main(args):
     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
     # This is discussed in Section 4.2 of the same paper.
     snr = compute_snr(noise_scheduler, timesteps)
-    base_weight = (
+    if noise_scheduler.config.prediction_type == "v_prediction":
+        # Velocity objective requires that we add one to SNR values before we divide by them.
+        snr = snr + 1
+    mse_loss_weights = (
         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
     )
-    if noise_scheduler.config.prediction_type == "v_prediction":
-        # Velocity objective needs to be floored to an SNR weight of one.
-        mse_loss_weights = base_weight + 1
-    else:
-        # Epsilon and sample both use the same loss weights.
-        mse_loss_weights = base_weight
-    # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-    # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-    mse_loss_weights[snr == 0] = 1.0
-    # We first calculate the original loss. Then we mean over the non-batch dimensions and
-    # rebalance the sample-wise losses with their respective loss weights.
-    # Finally, we take the mean of the rebalanced loss.
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
     loss = loss.mean()