Unverified commit 4edde134, authored by Sayak Paul, committed by GitHub

[SD3 training] refactor the density and weighting utilities. (#8591)

refactor the density and weighting utilities.
parent 074a7cc3
@@ -53,7 +53,11 @@ from diffusers import (
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params
+from diffusers.training_utils import (
+    cast_training_params,
+    compute_density_for_timestep_sampling,
+    compute_loss_weighting_for_sd3,
+)
 from diffusers.utils import (
     check_min_version,
     convert_unet_state_dict_to_peft,
@@ -473,11 +477,20 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
+        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+    )
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
-    parser.add_argument("--logit_mean", type=float, default=0.0)
-    parser.add_argument("--logit_std", type=float, default=1.0)
-    parser.add_argument("--mode_scale", type=float, default=1.29)
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -1477,16 +1490,13 @@ def main(args):
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
-                if args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
-                    u = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    u = torch.rand(size=(bsz,), device="cpu")
-                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-                else:
-                    u = torch.rand(size=(bsz,), device="cpu")
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
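To make the new call site concrete, here is a minimal, self-contained sketch of what it does at runtime. It assumes a diffusers version that ships these helpers, and uses FlowMatchEulerDiscreteScheduler as a stand-in for the script's noise_scheduler_copy; the batch size and scheme values are illustrative, not from this diff.

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling

scheduler = FlowMatchEulerDiscreteScheduler()  # 1000 training timesteps by default
bsz = 4
# Sample per-example density values u in (0, 1); the scheme decides the distribution.
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=bsz,
    logit_mean=0.0,
    logit_std=1.0,
    mode_scale=1.29,
)
# Scale u into discrete indices and look up the scheduler's timesteps, as the diff does.
indices = (u * scheduler.config.num_train_timesteps).long()
timesteps = scheduler.timesteps[indices]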
@@ -1507,19 +1517,11 @@ def main(args):
                 # Preconditioning of the model outputs.
                 model_pred = model_pred * (-sigmas) + noisy_model_input
-                # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
-                if args.weighting_scheme == "sigma_sqrt":
-                    weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "cosmap":
-                    bot = 1 - 2 * sigmas + 2 * sigmas**2
-                    weighting = 2 / (math.pi * bot)
-                else:
-                    weighting = torch.ones_like(sigmas)
+                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
-                # simplified flow matching aka 0-rectified flow matching loss
-                # target = model_input - noise
+                # flow matching loss
                 target = model_input
                 if args.with_prior_preservation:
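The helper reproduces the replaced branches exactly (see the training_utils hunk below), and the resulting weights post-weight the squared error per example. A minimal sketch with dummy tensors, assuming the mean-per-example reduction the script uses; the shapes are illustrative:

import torch
from diffusers.training_utils import compute_loss_weighting_for_sd3

sigmas = torch.rand(4, 1, 1, 1) * 0.98 + 0.01   # dummy per-example noise levels in (0, 1)
model_pred = torch.randn(4, 16, 32, 32)         # dummy preconditioned prediction
target = torch.randn_like(model_pred)           # dummy flow-matching target
weighting = compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)
# Post-weight the squared error, then average per example and over the batch.
loss = (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(4, -1).mean(dim=1).mean()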
@@ -51,6 +51,7 @@ from diffusers import (
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
 from diffusers.utils import (
     check_min_version,
     is_wandb_available,
@@ -471,11 +472,20 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
+        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+    )
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
-    parser.add_argument("--logit_mean", type=float, default=0.0)
-    parser.add_argument("--logit_std", type=float, default=1.0)
-    parser.add_argument("--mode_scale", type=float, default=1.29)
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -1541,16 +1551,13 @@ def main(args):
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
-                if args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
-                    u = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    u = torch.rand(size=(bsz,), device="cpu")
-                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-                else:
-                    u = torch.rand(size=(bsz,), device="cpu")
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
@@ -1587,16 +1594,9 @@ def main(args):
                 model_pred = model_pred * (-sigmas) + noisy_model_input
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
-                if args.weighting_scheme == "sigma_sqrt":
-                    weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "cosmap":
-                    bot = 1 - 2 * sigmas + 2 * sigmas**2
-                    weighting = 2 / (math.pi * bot)
-                else:
-                    weighting = torch.ones_like(sigmas)
+                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
-                # simplified flow matching aka 0-rectified flow matching loss
-                # target = model_input - noise
+                # flow matching loss
                 target = model_input
                 if args.with_prior_preservation:
 import contextlib
 import copy
+import math
 import random
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -220,6 +221,44 @@ def _set_state_dict_into_text_encoder(
     set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
+def compute_density_for_timestep_sampling(
+    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+):
+    """Compute the density for sampling the timesteps when doing SD3 training.
+    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+    """
+    if weighting_scheme == "logit_normal":
+        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+        u = torch.nn.functional.sigmoid(u)
+    elif weighting_scheme == "mode":
+        u = torch.rand(size=(batch_size,), device="cpu")
+        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+    else:
+        u = torch.rand(size=(batch_size,), device="cpu")
+    return u
+
+
+def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
+    """Computes loss weighting scheme for SD3 training.
+    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+    """
+    if weighting_scheme == "sigma_sqrt":
+        weighting = (sigmas**-2.0).float()
+    elif weighting_scheme == "cosmap":
+        bot = 1 - 2 * sigmas + 2 * sigmas**2
+        weighting = 2 / (math.pi * bot)
+    else:
+        weighting = torch.ones_like(sigmas)
+    return weighting
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
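As a quick sanity check of the schemes the new compute_loss_weighting_for_sd3 helper implements (a minimal sketch; the values in the comments are rounded):

import torch
from diffusers.training_utils import compute_loss_weighting_for_sd3

sigmas = torch.tensor([0.25, 0.50, 0.75])
print(compute_loss_weighting_for_sd3("sigma_sqrt", sigmas=sigmas))    # sigmas**-2 -> [16.0000, 4.0000, 1.7778]
print(compute_loss_weighting_for_sd3("cosmap", sigmas=sigmas))        # 2 / (pi * (1 - 2s + 2s^2)) -> [1.0186, 1.2732, 1.0186]
print(compute_loss_weighting_for_sd3("logit_normal", sigmas=sigmas))  # uniform-weighting fallback -> all ones

Note that sigma_sqrt grows without bound as sigma approaches 0, while cosmap stays bounded and peaks at sigma = 0.5; the logit_normal and mode schemes shape the timestep sampling density instead and fall through to unit weights here.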