Commit c1010662 (unverified), authored by Yingtian Liu, committed by GitHub
Browse files

Correct the SNR-weighted loss in the v-prediction case by adding 1 to the SNR only in the denominator (#6307)



* fix minsnr implementation for v-prediction case

* format code

* always compute snr when snr_gamma is specified

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent d4c7ab7b
...@@ -907,10 +907,12 @@ def main(): ...@@ -907,10 +907,12 @@ def main():
if args.snr_gamma is not None: if args.snr_gamma is not None:
snr = jnp.array(compute_snr(timesteps)) snr = jnp.array(compute_snr(timesteps))
if noise_scheduler.config.prediction_type == "v_prediction": snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma)
# Velocity objective requires that we add one to SNR values before we divide by them. if noise_scheduler.config.prediction_type == "epsilon":
snr = snr + 1 snr_loss_weights = snr_loss_weights / snr
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr elif noise_scheduler.config.prediction_type == "v_prediction":
snr_loss_weights = snr_loss_weights / (snr + 1)
loss = loss * snr_loss_weights loss = loss * snr_loss_weights
loss = loss.mean() loss = loss.mean()
......
...@@ -781,12 +781,13 @@ def main(): ...@@ -781,12 +781,13 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed. # Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper. # This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps) snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
# Velocity objective requires that we add one to SNR values before we divide by them. dim=1
snr = snr + 1 )[0]
mse_loss_weights = ( if noise_scheduler.config.prediction_type == "epsilon":
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr mse_loss_weights = mse_loss_weights / snr
) elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
......
...@@ -631,12 +631,13 @@ def main(): ...@@ -631,12 +631,13 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed. # Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper. # This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps) snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
# Velocity objective requires that we add one to SNR values before we divide by them. dim=1
snr = snr + 1 )[0]
mse_loss_weights = ( if noise_scheduler.config.prediction_type == "epsilon":
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr mse_loss_weights = mse_loss_weights / snr
) elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
......
...@@ -664,12 +664,13 @@ def main(): ...@@ -664,12 +664,13 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed. # Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper. # This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps) snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
# Velocity objective requires that we add one to SNR values before we divide by them. dim=1
snr = snr + 1 )[0]
mse_loss_weights = ( if noise_scheduler.config.prediction_type == "epsilon":
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr mse_loss_weights = mse_loss_weights / snr
) elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
......
...@@ -811,12 +811,13 @@ def main(): ...@@ -811,12 +811,13 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed. # Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper. # This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps) snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
# Velocity objective requires that we add one to SNR values before we divide by them. dim=1
snr = snr + 1 )[0]
mse_loss_weights = ( if noise_scheduler.config.prediction_type == "epsilon":
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr mse_loss_weights = mse_loss_weights / snr
) elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
......
...@@ -848,12 +848,13 @@ def main(): ...@@ -848,12 +848,13 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed. # Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper. # This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps) snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
# Velocity objective requires that we add one to SNR values before we divide by them. dim=1
snr = snr + 1 )[0]
mse_loss_weights = ( if noise_scheduler.config.prediction_type == "epsilon":
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr mse_loss_weights = mse_loss_weights / snr
) elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
......
...@@ -943,12 +943,13 @@ def main(): ...@@ -943,12 +943,13 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed. # Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper. # This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps) snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
# Velocity objective requires that we add one to SNR values before we divide by them. dim=1
snr = snr + 1 )[0]
mse_loss_weights = ( if noise_scheduler.config.prediction_type == "epsilon":
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr mse_loss_weights = mse_loss_weights / snr
) elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
......
...@@ -759,12 +759,13 @@ def main(): ...@@ -759,12 +759,13 @@ def main():
# Since we predict the noise instead of x_0, the original formulation is slightly changed. # Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper. # This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps) snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
# Velocity objective requires that we add one to SNR values before we divide by them. dim=1
snr = snr + 1 )[0]
mse_loss_weights = ( if noise_scheduler.config.prediction_type == "epsilon":
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr mse_loss_weights = mse_loss_weights / snr
) elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
......
...@@ -1062,12 +1062,13 @@ def main(args): ...@@ -1062,12 +1062,13 @@ def main(args):
# Since we predict the noise instead of x_0, the original formulation is slightly changed. # Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper. # This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps) snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
# Velocity objective requires that we add one to SNR values before we divide by them. dim=1
snr = snr + 1 )[0]
mse_loss_weights = ( if noise_scheduler.config.prediction_type == "epsilon":
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr mse_loss_weights = mse_loss_weights / snr
) elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
......
...@@ -1087,12 +1087,13 @@ def main(args): ...@@ -1087,12 +1087,13 @@ def main(args):
# Since we predict the noise instead of x_0, the original formulation is slightly changed. # Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper. # This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps) snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
# Velocity objective requires that we add one to SNR values before we divide by them. dim=1
snr = snr + 1 )[0]
mse_loss_weights = ( if noise_scheduler.config.prediction_type == "epsilon":
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr mse_loss_weights = mse_loss_weights / snr
) elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment