Unverified commit 24563ca6, authored by Bagheera and committed by GitHub
Browse files

SNR gamma fixes for v_prediction training (#5106)


Co-authored-by: bghira <bghira@users.github.com>
parent 914586f5
...@@ -908,6 +908,9 @@ def main(): ...@@ -908,6 +908,9 @@ def main():
if args.snr_gamma is not None: if args.snr_gamma is not None:
snr = jnp.array(compute_snr(timesteps)) snr = jnp.array(compute_snr(timesteps))
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
if noise_scheduler.config.prediction_type == "v_prediction":
# For the velocity (v-prediction) objective, add 1 to the SNR weights so that every weight is at least 1.
snr_loss_weights = snr_loss_weights + 1
loss = loss * snr_loss_weights loss = loss * snr_loss_weights
loss = loss.mean() loss = loss.mean()
......
...@@ -875,6 +875,9 @@ def main(): ...@@ -875,6 +875,9 @@ def main():
mse_loss_weights = ( mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
) )
if noise_scheduler.config.prediction_type == "v_prediction":
# For the velocity (v-prediction) objective, add 1 to the SNR weights so that every weight is at least 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and # We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights. # rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss. # Finally, we take the mean of the rebalanced loss.
......
...@@ -955,6 +955,9 @@ def main(): ...@@ -955,6 +955,9 @@ def main():
mse_loss_weights = ( mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
) )
if noise_scheduler.config.prediction_type == "v_prediction":
# For the velocity (v-prediction) objective, add 1 to the SNR weights so that every weight is at least 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and # We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights. # rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss. # Finally, we take the mean of the rebalanced loss.
......
...@@ -786,6 +786,9 @@ def main(): ...@@ -786,6 +786,9 @@ def main():
mse_loss_weights = ( mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
) )
if noise_scheduler.config.prediction_type == "v_prediction":
# For the velocity (v-prediction) objective, add 1 to the SNR weights so that every weight is at least 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and # We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights. # rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss. # Finally, we take the mean of the rebalanced loss.
......
...@@ -1075,6 +1075,9 @@ def main(args): ...@@ -1075,6 +1075,9 @@ def main(args):
mse_loss_weights = ( mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
) )
if noise_scheduler.config.prediction_type == "v_prediction":
# For the velocity (v-prediction) objective, add 1 to the SNR weights so that every weight is at least 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and # We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights. # rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss. # Finally, we take the mean of the rebalanced loss.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment