Unverified Commit 26e80e01 authored by Ethan Smith, committed by GitHub

fix min-snr implementation (#8466)

* fix min-snr implementation

https://github.com/kohya-ss/sd-scripts/blob/main/library/custom_train_functions.py#L66



* Update train_dreambooth.py

fix variable name mse_loss_weights

* fix divisor

* make style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 914a585b
......
@@ -1300,16 +1300,17 @@ def main(args):
 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
 # This is discussed in Section 4.2 of the same paper.
 snr = compute_snr(noise_scheduler, timesteps)
-base_weight = (
-    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-)
 if noise_scheduler.config.prediction_type == "v_prediction":
     # Velocity objective needs to be floored to an SNR weight of one.
-    mse_loss_weights = base_weight + 1
+    divisor = snr + 1
 else:
     # Epsilon and sample both use the same loss weights.
-    mse_loss_weights = base_weight
+    divisor = snr
+mse_loss_weights = (
+    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor
+)
 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
 loss = loss.mean()
......
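
To make the effect of the change easier to see, here is a minimal scalar sketch of the two weight computations in the hunk above. It is not the training script's tensor code: the function names `min_snr_weight` and `old_weight` are hypothetical, plain floats stand in for the `snr` tensor, and `gamma = 5.0` is only an illustrative value for `args.snr_gamma`.

```python
def min_snr_weight(snr: float, gamma: float, prediction_type: str) -> float:
    """Loss weight in the divisor-based form (scalar stand-in for the tensor code)."""
    # v-prediction divides by snr + 1; every other prediction type divides by snr.
    divisor = snr + 1.0 if prediction_type == "v_prediction" else snr
    return min(snr, gamma) / divisor


def old_weight(snr: float, gamma: float, prediction_type: str) -> float:
    """Loss weight in the base_weight form, kept here only for comparison."""
    base_weight = min(snr, gamma) / snr
    if prediction_type == "v_prediction":
        # The "+ 1" is applied to the weight itself, not to the divisor.
        return base_weight + 1.0
    return base_weight


if __name__ == "__main__":
    gamma = 5.0  # illustrative value, an assumption of this sketch
    for snr in (0.1, 1.0, 5.0, 100.0):
        for prediction_type in ("epsilon", "v_prediction"):
            print(
                f"snr={snr:<6} {prediction_type:<13}"
                f" old={old_weight(snr, gamma, prediction_type):.4f}"
                f" new={min_snr_weight(snr, gamma, prediction_type):.4f}"
            )
```

In this sketch the two forms agree for the non-velocity branch and differ only for `v_prediction`, where the old form never drops below 1 while the divisor-based form stays below 1.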