Unverified Commit b4077af2 authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[bug fix] using snr gamma and prior preservation loss in the dreambooth lora...


[bug fix] using snr gamma and prior preservation loss in the dreambooth lora sdxl training scripts (#6356)

* change timesteps used to calculate snr when --with_prior_preservation is enabled

* change timesteps used to calculate snr when --with_prior_preservation is enabled (canonical script)

* style

* revert canonical script to before snr gamma change

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 9f2bff50
......@@ -1819,9 +1819,17 @@ def main(args):
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
if args.with_prior_preservation:
# if we're using prior preservation, we calc snr for instance loss only -
# and hence only need timesteps corresponding to instance images
snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0)
else:
snr_timesteps = timesteps
snr = compute_snr(noise_scheduler, snr_timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment