"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fda1531d8a142fea42cef60aefcd2443286bb9a1"
Unverified Commit 6946facf authored by Rafie Walker's avatar Rafie Walker Committed by GitHub
Browse files

Implement SD3 loss weighting (#8528)



* Add lognorm and cosmap weighting

* Implement mode sampling

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* keep timestamp sampling fully on cpu

---------
Co-authored-by: default avatarKashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 130dd936
...@@ -1462,7 +1462,18 @@ def main(args): ...@@ -1462,7 +1462,18 @@ def main(args):
bsz = model_input.shape[0] bsz = model_input.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,)) # for weighting schemes where we sample timesteps non-uniformly
if args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
u = torch.rand(size=(bsz,), device="cpu")
u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(bsz,), device="cpu")
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching. # Add noise according to flow matching.
...@@ -1483,16 +1494,15 @@ def main(args): ...@@ -1483,16 +1494,15 @@ def main(args):
model_pred = model_pred * (-sigmas) + noisy_model_input model_pred = model_pred * (-sigmas) + noisy_model_input
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :) # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
if args.weighting_scheme == "sigma_sqrt": if args.weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float() weighting = (sigmas**-2.0).float()
elif args.weighting_scheme == "logit_normal": elif args.weighting_scheme == "cosmap":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). bot = 1 - 2 * sigmas + 2 * sigmas**2
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) weighting = 2 / (math.pi * bot)
weighting = torch.nn.functional.sigmoid(u) else:
elif args.weighting_scheme == "mode": weighting = torch.ones_like(sigmas)
# See sec 3.1 in the SD3 paper (20).
u = torch.rand(size=(bsz,), device=accelerator.device)
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
# simplified flow matching aka 0-rectified flow matching loss # simplified flow matching aka 0-rectified flow matching loss
# target = model_input - noise # target = model_input - noise
......
...@@ -1526,7 +1526,18 @@ def main(args): ...@@ -1526,7 +1526,18 @@ def main(args):
bsz = model_input.shape[0] bsz = model_input.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,)) # for weighting schemes where we sample timesteps non-uniformly
if args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
u = torch.rand(size=(bsz,), device="cpu")
u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(bsz,), device="cpu")
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching. # Add noise according to flow matching.
...@@ -1560,18 +1571,15 @@ def main(args): ...@@ -1560,18 +1571,15 @@ def main(args):
# Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs. # Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :) # and instead post-weight the loss
if args.weighting_scheme == "sigma_sqrt": if args.weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float() weighting = (sigmas**-2.0).float()
elif args.weighting_scheme == "logit_normal": elif args.weighting_scheme == "cosmap":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). bot = 1 - 2 * sigmas + 2 * sigmas**2
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) weighting = 2 / (math.pi * bot)
weighting = torch.nn.functional.sigmoid(u) else:
elif args.weighting_scheme == "mode": weighting = torch.ones_like(sigmas)
# See sec 3.1 in the SD3 paper (20).
u = torch.rand(size=(bsz,), device=accelerator.device)
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
# simplified flow matching aka 0-rectified flow matching loss # simplified flow matching aka 0-rectified flow matching loss
# target = model_input - noise # target = model_input - noise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment