Unverified commit 4edde134, authored by Sayak Paul, committed by GitHub

[SD3 training] refactor the density and weighting utilities. (#8591)

refactor the density and weighting utilities.
parent 074a7cc3
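
The commit deduplicates the timestep-density sampling and loss-weighting logic from both SD3 DreamBooth training scripts into two shared helpers in `diffusers.training_utils`. As orientation before the diff, here is a minimal, self-contained sketch of the new call pattern (helper names and signatures are taken from the diff below; the batch size and sigmas are illustrative placeholders, and this assumes a diffusers build that includes this commit):

import torch

from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)

bsz = 4  # illustrative batch size

# Density for sampling timesteps non-uniformly; "logit_normal" follows
# section 3.1 of the SD3 paper, "mode" and uniform are also supported.
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=bsz,
    logit_mean=0.0,
    logit_std=1.0,
    mode_scale=1.29,  # only used by the "mode" scheme
)

# Post-hoc loss weighting from the scheduler sigmas; "sigma_sqrt" yields
# sigma**-2, "cosmap" the cosine-map weighting, anything else all-ones.
sigmas = torch.linspace(0.1, 1.0, bsz)  # placeholder; the scripts read these from the scheduler
weighting = compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)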
@@ -53,7 +53,11 @@ from diffusers import (
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params
+from diffusers.training_utils import (
+    cast_training_params,
+    compute_density_for_timestep_sampling,
+    compute_loss_weighting_for_sd3,
+)
 from diffusers.utils import (
     check_min_version,
     convert_unet_state_dict_to_peft,
@@ -473,11 +477,20 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
+        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
     )
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+    )
-    parser.add_argument("--logit_mean", type=float, default=0.0)
-    parser.add_argument("--logit_std", type=float, default=1.0)
-    parser.add_argument("--mode_scale", type=float, default=1.29)
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -1477,16 +1490,13 @@ def main(args):
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
-                if args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
-                    u = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    u = torch.rand(size=(bsz,), device="cpu")
-                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-                else:
-                    u = torch.rand(size=(bsz,), device="cpu")
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
@@ -1507,19 +1517,11 @@ def main(args):
                 # Preconditioning of the model outputs.
                 model_pred = model_pred * (-sigmas) + noisy_model_input

-                # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
-                if args.weighting_scheme == "sigma_sqrt":
-                    weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "cosmap":
-                    bot = 1 - 2 * sigmas + 2 * sigmas**2
-                    weighting = 2 / (math.pi * bot)
-                else:
-                    weighting = torch.ones_like(sigmas)
+                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

-                # simplified flow matching aka 0-rectified flow matching loss
-                # target = model_input - noise
+                # flow matching loss
                 target = model_input

                 if args.with_prior_preservation:
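Background note, not part of the commit: written out, the schemes that these hunks consolidate are the following. For timestep sampling (the density helper):

$$u = \mathrm{sigmoid}(z), \qquad z \sim \mathcal{N}(\texttt{logit\_mean},\ \texttt{logit\_std}^2) \qquad \text{(logit\_normal, Sec. 3.1 of the SD3 paper)}$$

$$u \mapsto 1 - u - \texttt{mode\_scale}\left(\cos^2\!\left(\tfrac{\pi u}{2}\right) - 1 + u\right), \qquad u \sim \mathcal{U}[0,1] \qquad \text{(mode)}$$

and for the post-hoc loss weighting as a function of the scheduler sigma $\sigma$:

$$w_{\texttt{sigma\_sqrt}}(\sigma) = \sigma^{-2}, \qquad w_{\texttt{cosmap}}(\sigma) = \frac{2}{\pi\left(1 - 2\sigma + 2\sigma^2\right)}.$$

Any other scheme falls back to uniform sampling and unit weights, matching the `else` branches above.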
@@ -51,6 +51,7 @@ from diffusers import (
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
 from diffusers.utils import (
     check_min_version,
     is_wandb_available,
@@ -471,11 +472,20 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
+        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
     )
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+    )
-    parser.add_argument("--logit_mean", type=float, default=0.0)
-    parser.add_argument("--logit_std", type=float, default=1.0)
-    parser.add_argument("--mode_scale", type=float, default=1.29)
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -1541,16 +1551,13 @@ def main(args):
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
-                if args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
-                    u = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    u = torch.rand(size=(bsz,), device="cpu")
-                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-                else:
-                    u = torch.rand(size=(bsz,), device="cpu")
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
@@ -1587,16 +1594,9 @@
                 model_pred = model_pred * (-sigmas) + noisy_model_input

                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
-                if args.weighting_scheme == "sigma_sqrt":
-                    weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "cosmap":
-                    bot = 1 - 2 * sigmas + 2 * sigmas**2
-                    weighting = 2 / (math.pi * bot)
-                else:
-                    weighting = torch.ones_like(sigmas)
+                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
-                # simplified flow matching aka 0-rectified flow matching loss
-                # target = model_input - noise
+                # flow matching loss
                 target = model_input

                 if args.with_prior_preservation:
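Note, not part of the commit: with this change the default `--weighting_scheme` moves from `logit_normal` to `sigma_sqrt` and `cosmap` becomes a valid choice; `--logit_mean`/`--logit_std` only take effect under `logit_normal`, and `--mode_scale` only under `mode`. A hypothetical invocation exercising the new flags (the script name is assumed from context, and the required model/dataset flags are omitted):

accelerate launch train_dreambooth_sd3.py \
  --weighting_scheme=logit_normal \
  --logit_mean=0.0 \
  --logit_std=1.0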
 import contextlib
 import copy
+import math
 import random
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -220,6 +221,44 @@ def _set_state_dict_into_text_encoder(
     set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")


+def compute_density_for_timestep_sampling(
+    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+):
+    """Compute the density for sampling the timesteps when doing SD3 training.
+
+    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+    """
+    if weighting_scheme == "logit_normal":
+        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+        u = torch.nn.functional.sigmoid(u)
+    elif weighting_scheme == "mode":
+        u = torch.rand(size=(batch_size,), device="cpu")
+        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+    else:
+        u = torch.rand(size=(batch_size,), device="cpu")
+    return u
+
+
+def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
+    """Computes loss weighting scheme for SD3 training.
+
+    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+    """
+    if weighting_scheme == "sigma_sqrt":
+        weighting = (sigmas**-2.0).float()
+    elif weighting_scheme == "cosmap":
+        bot = 1 - 2 * sigmas + 2 * sigmas**2
+        weighting = 2 / (math.pi * bot)
+    else:
+        weighting = torch.ones_like(sigmas)
+    return weighting
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
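Note, not part of the commit: the diff shows `weighting` being computed, but the hunks end before the loss itself. One plausible way such a per-sample weighting is applied to the flow-matching MSE, sketched with placeholder tensors (shapes and the reduction are assumptions, not taken from this diff):

import torch

bsz = 4
model_pred = torch.randn(bsz, 16, 32, 32)   # placeholder model output
model_input = torch.randn(bsz, 16, 32, 32)  # placeholder latents
weighting = torch.rand(bsz, 1, 1, 1)        # broadcastable per-sample weights

# flow matching loss target, as in the scripts
target = model_input

# weighted MSE: weight the squared error, reduce per sample, then average
loss = torch.mean(
    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(bsz, -1),
    dim=1,
).mean()

The exact reduction used upstream may differ; the point is that `weighting` multiplies the squared error before averaging.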